mirror of
https://github.com/babalae/better-genshin-impact.git
synced 2026-03-30 10:19:51 +08:00
* 增加RodNet必须支持训练的单元测试;RodNet的torch链路改造成全张量计算,由此把之前忽略的参数dz、h_coeff、offset变得可学习 * 了解到损失函数CrossEntropyLoss内置了softmax,因此从forward方法中移出;offset是手动指定的偏置值,不是学习得到的,因此也移出到单独的PostProcess方法中 * 根据得到的源码整理RodNet;新增在数据集上达到一定准确率的单元测试
54 lines
2.0 KiB
C#
54 lines
2.0 KiB
C#
using BetterGenshinImpact.GameTask.AutoFishing;
|
||
using System;
|
||
using System.Collections.Generic;
|
||
using System.Linq;
|
||
using System.Text;
|
||
using System.Threading.Tasks;
|
||
using TorchSharp;
|
||
using static TorchSharp.torch;
|
||
|
||
namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests
|
||
{
|
||
[Collection("Init Collection")]
|
||
public partial class RodNetTests
|
||
{
|
||
public RodNetTests(TorchFixture torch)
|
||
{
|
||
if (!torch.UseTorch)
|
||
throw new NotSupportedException("torch加载失败,请检查BetterGenshinImpact项目编译环境的配置");
|
||
}
|
||
|
||
[Theory]
|
||
[InlineData(517.6326F, 548.49023F, 255.25723F, 263.55743F, 256.57538F, 351.56964F, 274.65656F, 333.1523F, 5)]
|
||
/// <summary>
|
||
/// 测试计算给到后处理之前的浮点数输出,Torch推理的结果和直接用数学计算的结果,两者的数值应该在转换到单精度时相同
|
||
/// </summary>
|
||
public void ComputeScoresTest_ShouldBeTheSame(double rod_x1, double rod_x2, double rod_y1, double rod_y2, double fish_x1, double fish_x2, double fish_y1, double fish_y2, int fish_label)
|
||
{
|
||
//
|
||
RodInput rodInput = new RodInput
|
||
{
|
||
rod_x1 = rod_x1,
|
||
rod_x2 = rod_x2,
|
||
rod_y1 = rod_y1,
|
||
rod_y2 = rod_y2,
|
||
fish_x1 = fish_x1,
|
||
fish_x2 = fish_x2,
|
||
fish_y1 = fish_y1,
|
||
fish_y2 = fish_y2,
|
||
fish_label = fish_label
|
||
};
|
||
RodNet sut = new RodNet();
|
||
|
||
//
|
||
Tensor outputTensor = sut.ComputeScores_Torch(rodInput);
|
||
double[] pred = RodNet.ComputeScores(rodInput);
|
||
|
||
//
|
||
Assert.Equal((float)pred[0], (float)outputTensor.data<double>()[0]); // 对比时降低精度,差不多就行
|
||
Assert.Equal((float)pred[1], (float)outputTensor.data<double>()[1]);
|
||
Assert.Equal((float)pred[2], (float)outputTensor.data<double>()[2]);
|
||
}
|
||
}
|
||
}
|