Files
better-genshin-impact/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.cs
FishmanTheMurloc 4787f4f2f9 尝试复现RodNet的训练结果 (#1688)
* 增加RodNet必须支持训练的单元测试;RodNet的torch链路改造成全张量计算,由此把之前忽略的参数dz、h_coeff、offset变得可学习

* 了解到损失函数CrossEntropyLoss内置了softmax,因此从forward方法中移出;offset是手动指定的偏置值,不是学习得到的,因此也移出到单独的PostProcess方法中

* 根据得到的源码整理RodNet;新增在数据集上达到一定准确率的单元测试
2025-06-14 20:03:14 +08:00

54 lines
2.0 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
using BetterGenshinImpact.GameTask.AutoFishing;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using static TorchSharp.torch;
namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests
{
[Collection("Init Collection")]
public partial class RodNetTests
{
public RodNetTests(TorchFixture torch)
{
if (!torch.UseTorch)
throw new NotSupportedException("torch加载失败请检查BetterGenshinImpact项目编译环境的配置");
}
[Theory]
[InlineData(517.6326F, 548.49023F, 255.25723F, 263.55743F, 256.57538F, 351.56964F, 274.65656F, 333.1523F, 5)]
/// <summary>
/// 测试计算给到后处理之前的浮点数输出Torch推理的结果和直接用数学计算的结果两者的数值应该在转换到单精度时相同
/// </summary>
public void ComputeScoresTest_ShouldBeTheSame(double rod_x1, double rod_x2, double rod_y1, double rod_y2, double fish_x1, double fish_x2, double fish_y1, double fish_y2, int fish_label)
{
//
RodInput rodInput = new RodInput
{
rod_x1 = rod_x1,
rod_x2 = rod_x2,
rod_y1 = rod_y1,
rod_y2 = rod_y2,
fish_x1 = fish_x1,
fish_x2 = fish_x2,
fish_y1 = fish_y1,
fish_y2 = fish_y2,
fish_label = fish_label
};
RodNet sut = new RodNet();
//
Tensor outputTensor = sut.ComputeScores_Torch(rodInput);
double[] pred = RodNet.ComputeScores(rodInput);
//
Assert.Equal((float)pred[0], (float)outputTensor.data<double>()[0]); // 对比时降低精度,差不多就行
Assert.Equal((float)pred[1], (float)outputTensor.data<double>()[1]);
Assert.Equal((float)pred[2], (float)outputTensor.data<double>()[2]);
}
}
}