diff --git a/BetterGenshinImpact/GameTask/AutoFishing/RodInput.cs b/BetterGenshinImpact/GameTask/AutoFishing/RodInput.cs index cb33932e..597d473c 100644 --- a/BetterGenshinImpact/GameTask/AutoFishing/RodInput.cs +++ b/BetterGenshinImpact/GameTask/AutoFishing/RodInput.cs @@ -1,4 +1,6 @@ -namespace BetterGenshinImpact.GameTask.AutoFishing; +using static TorchSharp.torch; + +namespace BetterGenshinImpact.GameTask.AutoFishing; public record RodInput { diff --git a/BetterGenshinImpact/GameTask/AutoFishing/RodNet.cs b/BetterGenshinImpact/GameTask/AutoFishing/RodNet.cs index 86c1656e..3934977b 100644 --- a/BetterGenshinImpact/GameTask/AutoFishing/RodNet.cs +++ b/BetterGenshinImpact/GameTask/AutoFishing/RodNet.cs @@ -37,8 +37,14 @@ namespace BetterGenshinImpact.GameTask.AutoFishing; /// tmd 今天我意识到 XXXX可不就是XXXX /// /// 哦 到这一步以后剩下的就很弱智了 远了挪近一点 近了挪远一点 调调参差不多得了 +/// +/// *后来又新增了一些访谈内容: +/// +/// 额 总之就是要求不能把不咬钩的识别成咬钩的 但是咬钩的可以识别成不咬钩的 +/// +/// 然后就可视化一下onehot在不同距离的结果 加一个offset使得模型输出的结果在保证可以predict距离正好的结果的同时距离范围尽可能小 /// -public class RodNet : Module +public class RodNet : Module { const double alpha = 1734.34 / 2.5; // fitted parameters @@ -75,23 +81,21 @@ public class RodNet : Module static readonly double[] offset = { 0.8, 0.4, 0.35, 0.35, 0.6, 0.3, 0.3, 0.8, 0.8, 0.8, 0.8 }; - private readonly Module layers; + private Parameter thetaParameter; + private Parameter bParameter; + private Parameter dzParameter; + private Parameter hCoeffParameter; public RodNet() : base("RodNet") { - var weight = tensor(RodNet.weight, ScalarType.Float64); - var bias = tensor(RodNet.bias, ScalarType.Float64); + long num_embeddings = RodNet.weight.GetLength(0); + long embedding_dim = 3; - RodLayer1 rodLayer1 = new RodLayer1(num_embeddings: weight.shape[0], embedding_dim: weight.shape[1], input_dim: 3, output_dim: 3); - rodLayer1.SetWeightsManually(weight, bias); + this.thetaParameter = new Parameter(torch.randn(num_embeddings, embedding_dim, dtype: ScalarType.Float64)); + this.bParameter = new Parameter(torch.randn(num_embeddings, embedding_dim, dtype: ScalarType.Float64)); - var modules = new List<(string, Module)> - { - ($"rodLayer1", rodLayer1), - ($"softmax", nn.Softmax(1)) - }; - - layers = Sequential(modules); + this.dzParameter = new Parameter(torch.zeros(num_embeddings, 1, dtype: ScalarType.Float64)); + this.hCoeffParameter = new Parameter(torch.zeros(num_embeddings, 1, dtype: ScalarType.Float64)); RegisterComponents(); } @@ -109,58 +113,26 @@ public class RodNet : Module dst[i] /= sum; } } - public record NetInput(double dist, int fish_label); - public static NetInput? GeometryProcessing(RodInput input) + + internal static int GetRodState(RodInput input) { - double a, b, v0, u, v, h; + double[] pred = ComputeScores(input); - a = (input.rod_x2 - input.rod_x1) / 2 / alpha; - b = (input.rod_y2 - input.rod_y1) / 2 / alpha; - h = (input.fish_y2 - input.fish_y1) / 2 / alpha; + return Array.IndexOf(pred, pred.Max()); + } - if (a < b) - { - b = Math.Sqrt(a * b); - a = b + 1e-6; - } + public static double[] ComputeScores(RodInput input) + { + var (y0, z0, t, u, v, h) = GetRodStatePreProcess(input); - v0 = (288 - (input.rod_y1 + input.rod_y2) / 2) / alpha; - - u = (input.fish_x1 + input.fish_x2 - input.rod_x1 - input.rod_x2) / 2 / alpha; - v = (288 - (input.fish_y1 + input.fish_y2) / 2) / alpha; v -= h * h_coeff[input.fish_label]; - - double y0, z0, t; double x, y, dist; - y0 = Math.Sqrt(Math.Pow(a, 4) - b * b + a * a * (1 - b * b + v0 * v0)) / (a * a); - z0 = b / (a * a); - t = a * a * (y0 * b + v0) / (a * a - b * b); - x = u * (z0 + dz[input.fish_label]) * Math.Sqrt(1 + t * t) / (t - v); y = (z0 + dz[input.fish_label]) * (1 + t * v) / (t - v); dist = Math.Sqrt(x * x + (y - y0) * (y - y0)); - return new NetInput(dist, input.fish_label); - } - - internal static int GetRodState(RodInput input) - { - NetInput? netInput = GeometryProcessing(input); - if (netInput is null) - { - return -1; - } - - double[] pred = ComputeScores(netInput); - - return Array.IndexOf(pred, pred.Max()); - } - - public static double[] ComputeScores(NetInput netInput) - { - double dist = netInput.dist; - int fish_label = netInput.fish_label; + int fish_label = input.fish_label; double[] logits = new double[3]; for (int i = 0; i < 3; i++) @@ -177,73 +149,113 @@ public class RodNet : Module internal int GetRodState_Torch(RodInput input) { - NetInput? netInput = GeometryProcessing(input); - if (netInput is null) - { - return -1; - } - - Tensor outputTensor = ComputeScores_Torch(netInput); + using var _ = no_grad(); + Tensor outputTensor = ComputeScores_Torch(input); var max = argmax(outputTensor); return (int)max.item(); } - public Tensor ComputeScores_Torch(NetInput netInput) + public Tensor ComputeScores_Torch(RodInput input) { - double dist = netInput.dist; - int fish_label = netInput.fish_label; + using var _ = no_grad(); + this.SetWeightsManually(); - Tensor inputTensor = cat([tensor(new double[,] { { dist } }, dtype: ScalarType.Float64), - tensor(new int[,] { {fish_label } }, dtype: ScalarType.Int32)]).T; - var outputTensor = forward(inputTensor); + var (y0, z0, t, u, v, h) = GetRodStatePreProcess(input); - outputTensor[0][0] = outputTensor[0][0] - RodNet.offset[fish_label]; + Tensor fishLabel = tensor(new double[] { input.fish_label }, dtype: ScalarType.Int32); + Tensor uv = tensor(new double[,] { { u, v } }, dtype: ScalarType.Float64); + Tensor y0z0t = tensor(new double[,] { { y0, z0, t } }, dtype: ScalarType.Float64); + Tensor h_ = tensor(new double[,] { { h } }, dtype: ScalarType.Float64); - return outputTensor; + var logits = forward(fishLabel, uv, y0z0t, h_); + var output = PostProcess(logits, fishLabel); + + return output; } - public override Tensor forward(Tensor input) + /// + /// 使用时直接赋值已知权重 + /// + public void SetWeightsManually() { - return layers.forward(input); - } -} - -public class RodLayer1 : Module -{ - private readonly Embedding embedding1; - private readonly Embedding embedding2; - private readonly Linear linear; - public RodLayer1(long num_embeddings, long embedding_dim, long input_dim, long output_dim) - : base("RodLinear") - { - embedding1 = torch.nn.Embedding(num_embeddings, embedding_dim); - embedding2 = torch.nn.Embedding(num_embeddings, embedding_dim); - linear = torch.nn.Linear(input_dim, output_dim); - - RegisterComponents(); + var weightTensor = tensor(RodNet.weight, ScalarType.Float64); + var biasTensor = tensor(RodNet.bias, ScalarType.Float64); + var dzTensor = tensor(RodNet.dz, ScalarType.Float64).reshape([RodNet.dz.Length, 1]); + var h_coeffTensor = tensor(RodNet.h_coeff, ScalarType.Float64).reshape([RodNet.h_coeff.Length, 1]); + this.thetaParameter = new Parameter(weightTensor); + this.bParameter = new Parameter(biasTensor); + this.dzParameter = new Parameter(dzTensor); + this.hCoeffParameter = new Parameter(h_coeffTensor); } - public void SetWeightsManually(Tensor weight, Tensor bias) + public override Tensor forward(Tensor fishLabel, Tensor uv, Tensor y0z0t, Tensor h) { - embedding1.weight = new Parameter(weight); - embedding2.weight = new Parameter(bias); + var uvSplit = uv.split([1, 1], dim: 1); + Tensor u = uvSplit[0]; + Tensor v = uvSplit[1]; + + var y0z0tSplit = y0z0t.split([1, 1, 1], dim: 1); + Tensor y0 = y0z0tSplit[0]; + Tensor z0 = y0z0tSplit[1]; + Tensor t = y0z0tSplit[2]; + + v = v - h * hCoeffParameter[fishLabel]; + + Tensor x, y, dist; + + var dz = dzParameter[fishLabel]; + x = u * (z0 + dz) * torch.sqrt(1 + t * t) / (t - v); + y = (z0 + dz) * (1 + t * v) / (t - v); + dist = torch.sqrt(x * x + (y - y0) * (y - y0)); + + Tensor logits = this.thetaParameter[fishLabel] * dist + this.bParameter[fishLabel]; + + return logits; } - public override Tensor forward(Tensor input) + public Tensor PostProcess(Tensor logits, Tensor fishLabel) { - var splitInput = input.split([1, 1], dim: 1); - var dist = splitInput[0]; - var fish_label = splitInput[1].to(ScalarType.Int32).flatten(); + var x_softmax = torch.nn.functional.softmax(logits, 1); - var embed1 = embedding1.forward(fish_label); - //Console.WriteLine(String.Join(",", embed1.data())); - var embed2 = embedding2.forward(fish_label); - //Console.WriteLine(String.Join(",", embed2.data())); + Tensor x_offset = tensor(fishLabel.data().Select(l => RodNet.offset[l]).ToArray()); - linear.weight = new Parameter(embed1.T); - linear.bias = new Parameter(embed2); + x_softmax[torch.arange(x_offset.shape[0]), 0] -= x_offset; + return x_softmax; + } - return linear.forward(dist); + /// + /// 根据rod和fish的坐标计算y0z0t、uv、h + /// + /// + /// y0, z0, t, u, v, h + public static (double, double, double, double, double, double) GetRodStatePreProcess(RodInput input) + { + /* + * 以下为hutaofisher代码中关于部分变量的意义的注释 + # uv: screen coordinate of bbox center of the fish + # abv0: rod shape and center coordinate in screen + */ + double a, b, v0, u, v, h; + + a = (input.rod_x2 - input.rod_x1) / 2 / alpha; + b = (input.rod_y2 - input.rod_y1) / 2 / alpha; + h = (input.fish_y2 - input.fish_y1) / 2 / alpha; + + if (a < b) + { + b = Math.Sqrt(a * b); + a = b + 1e-6; + } + v0 = (288 - (input.rod_y1 + input.rod_y2) / 2) / alpha; + u = (input.fish_x1 + input.fish_x2 - input.rod_x1 - input.rod_x2) / 2 / alpha; + v = (288 - (input.fish_y1 + input.fish_y2) / 2) / alpha; + double y0, z0, t; + + y0 = Math.Sqrt(Math.Pow(a, 4) - b * b + a * a * (1 - b * b + v0 * v0)) / (a * a); + z0 = b / (a * a); + t = a * a * (y0 * b + v0) / (a * a - b * b); + + return (y0, z0, t, u, v, h); } } \ No newline at end of file diff --git a/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.Training.cs b/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.Training.cs new file mode 100644 index 00000000..86a8ed1d --- /dev/null +++ b/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.Training.cs @@ -0,0 +1,392 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static TorchSharp.torch.nn; +using static TorchSharp.torch; +using TorchSharp; +using System.Diagnostics; +using System.Collections; +using BetterGenshinImpact.GameTask.AutoFishing; + +namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests +{ + public partial class RodNetTests + { + /// + /// RodNet验证,应在数据集上达到一定准确率 + /// + [Theory] + [InlineData(@"..\..\..\Assets\AutoFishing\data_selected.csv")] + public void Training_AccuracyShouldBeOK(string dataLocation) + { + // + using var _ = no_grad(); + + var device = + torch.cuda.is_available() ? torch.CUDA : + torch.mps_is_available() ? torch.MPS : + torch.CPU; + var loss = CrossEntropyLoss(); + var sut = new RodNet().to((Device)device); + sut.SetWeightsManually(); + + using var test_reader = new CSVReader(Enumerable.Repeat(false, 8).Concat(Enumerable.Repeat(true, 2)), Path.GetFullPath(dataLocation), (Device)device); + + // + var accuracy = evaluate(test_reader.GetBatches(eval_batch_size), sut, loss); + + // + Assert.True(accuracy > 0.8); + } + + /// + /// RodNet必须粗略地支持训练 + /// + [Fact] + public void Training_ShouldBeDifferentiable() + { + // + RodInput input = new RodInput(); + var (y0, z0, t, u, v, h) = RodNet.GetRodStatePreProcess(input); + + Tensor fishLabel = tensor(new double[] { input.fish_label }, dtype: ScalarType.Int32); + Tensor uv = tensor(new double[,] { { u, v } }, dtype: ScalarType.Float64); + Tensor y0z0t = tensor(new double[,] { { y0, z0, t } }, dtype: ScalarType.Float64); + Tensor h_ = tensor(new double[,] { { h } }, dtype: ScalarType.Float64); + RodNet sut = new RodNet(); + + // + Tensor output = sut.forward(fishLabel, uv, y0z0t, h_); + output.backward([torch.ones_like(output)]); + // + } + + #region 训练相关代码 + // 这部分代码改编自TorchSharpExamples的CSharpExamples.TextClassification + private const long batch_size = 32; + private const long eval_batch_size = 32; + + internal static RodNet Run(int epochs, int timeout, string dataLocation) + { + torch.random.manual_seed(1); + + var device = + torch.cuda.is_available() ? torch.CUDA : + torch.mps_is_available() ? torch.MPS : + torch.CPU; + + Console.WriteLine(); + Console.WriteLine($"\tRunning TextClassification on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}."); + Console.WriteLine(); + + Console.WriteLine($"\tPreparing training and test data..."); + + using (var reader = new CSVReader(Enumerable.Repeat(true, 8).Concat(Enumerable.Repeat(false, 2)), dataLocation, (Device)device)) + { + Console.WriteLine($"\tCreating the model..."); + Console.WriteLine(); + + var model = new RodNet().to((Device)device); + + var loss = CrossEntropyLoss(); + var lr = 1e-2; + var optimizer = torch.optim.SGD(model.parameters(), learningRate: lr); + var scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min: 0); + + var totalTime = new Stopwatch(); + totalTime.Start(); + + foreach (var epoch in Enumerable.Range(1, epochs)) + { + + var sw = new Stopwatch(); + sw.Start(); + + train(epoch, reader.GetBatches(batch_size), model, loss, optimizer); + + sw.Stop(); + + Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {optimizer.ParamGroups.First().LearningRate:0.0000} | time: {sw.Elapsed.TotalSeconds:0.0}s\n"); + scheduler.step(); + + if (totalTime.Elapsed.TotalSeconds > timeout) break; + } + + totalTime.Stop(); + + using (var test_reader = new CSVReader(Enumerable.Repeat(false, 8).Concat(Enumerable.Repeat(true, 2)), dataLocation, (Device)device)) + { + + var sw = new Stopwatch(); + sw.Start(); + + var accuracy = evaluate(test_reader.GetBatches(eval_batch_size), model, loss); + + sw.Stop(); + + Console.WriteLine($"\nEnd of training: test accuracy: {accuracy:0.00} | eval time: {sw.Elapsed.TotalSeconds:0.0}s\n"); + scheduler.step(); + } + + foreach (var (name, param) in model.named_parameters()) + { + switch (param.dtype) + { + case ScalarType.Int64: + Console.WriteLine($"参数{name}={String.Join(", ", param.data())}"); + break; + case ScalarType.Float32: + Console.WriteLine($"参数{name}={String.Join(", ", param.data())}"); + break; + case ScalarType.Float64: + Console.WriteLine($"参数{name}={String.Join(", ", param.data())}"); + break; + } + } + + return model; + } + + } + + static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> train_data, RodNet model, Loss criterion, torch.optim.Optimizer optimizer) + { + model.train(); + + double total_acc = 0.0; + long total_count = 0; + long log_interval = 1; + + var batch = 0; + + var batch_count = train_data.Count(); + + using (var d = torch.NewDisposeScope()) + { + foreach (var (y0z0t, uv, h, fish_label, success) in train_data) + { + + optimizer.zero_grad(); + + using (var predicted_labels = model.forward(fish_label, uv, y0z0t, h)) + { + var loss = criterion.forward(predicted_labels, success.to(ScalarType.Int64)); + loss.backward(); + torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1); + optimizer.step(); + + total_acc += (predicted_labels.argmax(1) == success).sum().to(torch.CPU).item(); + total_count += success.size(0); + } + + batch += 1; + if (batch % log_interval == 0) + { + var accuracy = total_acc / total_count; + Console.WriteLine($"epoch: {epoch} | batch: {batch} / {batch_count} | accuracy: {accuracy:0.00}"); + } + } + } + } + + static double evaluate(IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> test_data, RodNet model, Loss criterion) + { + model.eval(); + + double total_acc = 0.0; + long total_count = 0; + + using (var d = torch.NewDisposeScope()) + { + foreach (var (y0z0t, uv, h, fish_label, success) in test_data) + { + using (var predicted_labels = model.forward(fish_label, uv, y0z0t, h)) + { + var loss = criterion.forward(predicted_labels, success.to(ScalarType.Int64)); + + total_acc += (predicted_labels.argmax(1) == success).sum().to(torch.CPU).item(); + total_count += success.size(0); + } + } + + return total_acc / total_count; + } + } + #endregion + } + + internal class CSVReader : IDisposable + { + /// + /// + /// + /// 按长度分组,布尔值代表每组内此序号元素是否被读取。例如8个true2个false就是将近80%进入训练集 + /// + /// + public CSVReader(IEnumerable takeMask, string path, Device device) + { + this.takeMask = takeMask.ToArray(); + _path = path; + _device = device; + } + + private readonly bool[] takeMask; + private readonly string _path; + private readonly Device _device; + + + public IEnumerable Enumerate() + { + var all = File.ReadLines(_path).Skip(1); // 跳过首行列名 + int count = takeMask.Length; + int maskCount = all.Count() / count; + for (int i = 0; i < maskCount + 1; i++) + { + int lastGroupCount = (i == maskCount) ? (all.Count() % count) : count; + for (int j = 0; j < lastGroupCount; j++) + { + if (takeMask[j]) + { + yield return ParseLine(all.Skip(i * count + j).First()); + } + } + } + } + + public IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> GetBatches(long batch_size) + { + // This data set fits in memory, so we will simply load it all and cache it between epochs. + + var inputs = new List(); + + if (_data == null) + { + + _data = new List<(Tensor, Tensor, Tensor, Tensor, Tensor)>(); + + var counter = 0; + var lines = Enumerate().ToList(); + var left = lines.Count; + + foreach (var line in lines) + { + + inputs.Add(line); + left -= 1; + + if (++counter == batch_size || left == 0) + { + _data.Add(Batchifier(inputs)); + inputs.Clear(); + counter = 0; + } + } + } + + return _data; + } + + private List<(Tensor, Tensor, Tensor, Tensor, Tensor)> _data; + private bool disposedValue; + + /// + /// 将csv中的数据进行初步转换 + /// + /// + /// y0z0t、uv、h、fish_label、success张量 + private (Tensor, Tensor, Tensor, Tensor, Tensor) Batchifier(IEnumerable input) + { + var y0List = new List(); + var z0List = new List(); + var tList = new List(); + var uList = new List(); + var vList = new List(); + var hList = new List(); + var labelList = new List(); + var successList = new List(); + + foreach (var line in input) + { + int fish_label = line.fish_label; + int success = line.success; + RodInput rodInput = new RodInput() + { + rod_x1 = line.rod_x1, + rod_x2 = line.rod_x2, + rod_y1 = line.rod_y1, + rod_y2 = line.rod_y2, + fish_x1 = line.fish_x1, + fish_x2 = line.fish_x2, + fish_y1 = line.fish_y1, + fish_y2 = line.fish_y2 + }; + var (y0, z0, t, u, v, h) = RodNet.GetRodStatePreProcess(rodInput); + + y0List.Add(y0); + z0List.Add(z0); + tList.Add(t); + uList.Add(u); + vList.Add(v); + hList.Add(h); + labelList.Add(fish_label); + successList.Add(success); + } + + Tensor y0Tensor = tensor(y0List, dtype: ScalarType.Float64).to(_device); + Tensor z0Tensor = tensor(z0List, dtype: ScalarType.Float64).to(_device); + Tensor tTensor = tensor(tList, dtype: ScalarType.Float64).to(_device); + Tensor uTensor = tensor(uList, dtype: ScalarType.Float64).to(_device); + Tensor vTensor = tensor(vList, dtype: ScalarType.Float64).to(_device); + Tensor hTensor = tensor(hList, dtype: ScalarType.Float64).to(_device); + Tensor fish_labelTensor = tensor(labelList, dtype: ScalarType.Int32).to(_device); + Tensor successTensor = tensor(successList, dtype: ScalarType.Int32).to(_device); + + return (torch.stack([y0Tensor, z0Tensor, tTensor], dim: 1), torch.stack([uTensor, vTensor], dim: 1), hTensor.unsqueeze(1), fish_labelTensor, successTensor); + } + + /// + /// csv的列定义默认为time,bite_time,rod_x1,rod_x2,rod_y1,rod_y2,fish_x1,fish_x2,fish_y1,fish_y2,fish_label,success + /// + /// + /// + public Data ParseLine(string line) + { + var columns = line.Split(",").ToArray(); + + return new Data(float.Parse(columns[2]), float.Parse(columns[3]), float.Parse(columns[4]), float.Parse(columns[5]), + float.Parse(columns[6]), float.Parse(columns[7]), float.Parse(columns[8]), float.Parse(columns[9]), + int.Parse(columns[10]), int.Parse(columns[11])); + } + + protected virtual void Dispose(bool disposing) + { + if (!disposedValue) + { + if (disposing && _data != null) + { + foreach (var (y0z0t, uv, h, label, success) in _data) + { + y0z0t.Dispose(); + uv.Dispose(); + h.Dispose(); + label.Dispose(); + success.Dispose(); + } + } + + disposedValue = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } + + internal record Data(float rod_x1, float rod_x2, float rod_y1, float rod_y2, float fish_x1, float fish_x2, float fish_y1, float fish_y2, int fish_label, int success); +} diff --git a/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.cs b/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.cs index d4546c83..6440c8e9 100644 --- a/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.cs +++ b/Test/BetterGenshinImpact.UnitTest/GameTaskTests/AutoFishingTests/RodNetTests.cs @@ -1,17 +1,16 @@ using BetterGenshinImpact.GameTask.AutoFishing; -using BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests; using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; -using static BetterGenshinImpact.GameTask.AutoFishing.RodNet; +using TorchSharp; using static TorchSharp.torch; namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests { [Collection("Init Collection")] - public class RodNetTests + public partial class RodNetTests { public RodNetTests(TorchFixture torch) { @@ -42,9 +41,8 @@ namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests RodNet sut = new RodNet(); // - NetInput netInput = GeometryProcessing(rodInput) ?? throw new NullReferenceException(); - Tensor outputTensor = sut.ComputeScores_Torch(netInput); - double[] pred = ComputeScores(netInput); + Tensor outputTensor = sut.ComputeScores_Torch(rodInput); + double[] pred = RodNet.ComputeScores(rodInput); // Assert.Equal((float)pred[0], (float)outputTensor.data()[0]); // 对比时降低精度,差不多就行