using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using TorchSharp;
using Xunit;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
using BetterGenshinImpact.GameTask.AutoFishing;
namespace BetterGenshinImpact.UnitTest.GameTaskTests.AutoFishingTests
{
public partial class RodNetTests
{
    /// <summary>
    /// RodNet validation: the manually-weighted model should reach a minimum accuracy on the data set.
    /// </summary>
    /// <param name="dataLocation">relative path of the csv data set</param>
    [Theory]
    [InlineData(@"..\..\..\Assets\AutoFishing\data_selected.csv")]
    public void Training_AccuracyShouldBeOK(string dataLocation)
    {
        // Arrange
        using var _ = no_grad();    // evaluation only; no gradients needed
        var device =
            torch.cuda.is_available() ? torch.CUDA :
            torch.mps_is_available() ? torch.MPS :
            torch.CPU;
        var loss = CrossEntropyLoss();
        var sut = new RodNet().to((Device)device);
        sut.SetWeightsManually();
        // 8×false + 2×true: evaluate on the ~20% test split of every group of 10 rows
        using var test_reader = new CSVReader(Enumerable.Repeat(false, 8).Concat(Enumerable.Repeat(true, 2)), Path.GetFullPath(dataLocation), (Device)device);

        // Act
        var accuracy = evaluate(test_reader.GetBatches(eval_batch_size), sut, loss);

        // Assert
        Assert.True(accuracy > 0.8);
    }

    /// <summary>
    /// RodNet must roughly support training: a full forward + backward pass runs without throwing.
    /// </summary>
    [Fact]
    public void Training_ShouldBeDifferentiable()
    {
        // Arrange
        RodInput input = new RodInput();
        var (y0, z0, t, u, v, h) = RodNet.GetRodStatePreProcess(input);
        Tensor fishLabel = tensor(new double[] { input.fish_label }, dtype: ScalarType.Int32);
        Tensor uv = tensor(new double[,] { { u, v } }, dtype: ScalarType.Float64);
        Tensor y0z0t = tensor(new double[,] { { y0, z0, t } }, dtype: ScalarType.Float64);
        Tensor h_ = tensor(new double[,] { { h } }, dtype: ScalarType.Float64);
        RodNet sut = new RodNet();

        // Act
        Tensor output = sut.forward(fishLabel, uv, y0z0t, h_);
        output.backward([torch.ones_like(output)]);

        // Assert: reaching this point without an exception is the success criterion
    }

    #region Training code (adapted from TorchSharpExamples' CSharpExamples.TextClassification)

    private const long batch_size = 32;
    private const long eval_batch_size = 32;

    /// <summary>
    /// Trains a RodNet on the given csv and prints the learned parameters so they can be
    /// copied back into the hand-written weights.
    /// </summary>
    /// <param name="epochs">maximum number of epochs</param>
    /// <param name="timeout">wall-clock budget in seconds; training stops once it elapses</param>
    /// <param name="dataLocation">csv file path</param>
    /// <returns>the trained model</returns>
    internal static RodNet Run(int epochs, int timeout, string dataLocation)
    {
        torch.random.manual_seed(1);    // deterministic runs

        var device =
            torch.cuda.is_available() ? torch.CUDA :
            torch.mps_is_available() ? torch.MPS :
            torch.CPU;
        Console.WriteLine();
        // Fixed copy-paste leftover: this trains RodNet, not TextClassification.
        Console.WriteLine($"\tRunning RodNet training on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");
        Console.WriteLine();
        Console.WriteLine($"\tPreparing training and test data...");

        // 8×true + 2×false: train on ~80% of every group of 10 rows
        using (var reader = new CSVReader(Enumerable.Repeat(true, 8).Concat(Enumerable.Repeat(false, 2)), dataLocation, (Device)device))
        {
            Console.WriteLine($"\tCreating the model...");
            Console.WriteLine();

            var model = new RodNet().to((Device)device);
            var loss = CrossEntropyLoss();
            var lr = 1e-2;
            var optimizer = torch.optim.SGD(model.parameters(), learningRate: lr);
            var scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min: 0);

            var totalTime = new Stopwatch();
            totalTime.Start();

            foreach (var epoch in Enumerable.Range(1, epochs))
            {
                var sw = new Stopwatch();
                sw.Start();
                train(epoch, reader.GetBatches(batch_size), model, loss, optimizer);
                sw.Stop();
                Console.WriteLine($"\nEnd of epoch: {epoch} | lr: {optimizer.ParamGroups.First().LearningRate:0.0000} | time: {sw.Elapsed.TotalSeconds:0.0}s\n");
                scheduler.step();
                // Honor the wall-clock budget even if epochs remain.
                if (totalTime.Elapsed.TotalSeconds > timeout) break;
            }

            totalTime.Stop();

            // Evaluate on the complementary ~20% test split.
            using (var test_reader = new CSVReader(Enumerable.Repeat(false, 8).Concat(Enumerable.Repeat(true, 2)), dataLocation, (Device)device))
            {
                var sw = new Stopwatch();
                sw.Start();
                var accuracy = evaluate(test_reader.GetBatches(eval_batch_size), model, loss);
                sw.Stop();
                Console.WriteLine($"\nEnd of training: test accuracy: {accuracy:0.00} | eval time: {sw.Elapsed.TotalSeconds:0.0}s\n");
                // (removed a leftover scheduler.step() here — training is over, it had no effect)
            }

            // Dump the learned parameters, reading the data with the accessor matching each dtype.
            foreach (var (name, param) in model.named_parameters())
            {
                switch (param.dtype)
                {
                    case ScalarType.Int64:
                        Console.WriteLine($"参数{name}={String.Join(", ", param.data<long>())}");
                        break;
                    case ScalarType.Float32:
                        Console.WriteLine($"参数{name}={String.Join(", ", param.data<float>())}");
                        break;
                    case ScalarType.Float64:
                        Console.WriteLine($"参数{name}={String.Join(", ", param.data<double>())}");
                        break;
                }
            }

            return model;
        }
    }

    /// <summary>
    /// Runs one training epoch over the batched data, logging the running accuracy.
    /// </summary>
    /// <param name="epoch">1-based epoch number (for logging only)</param>
    /// <param name="train_data">batches of (y0z0t, uv, h, fish_label, success) tensors</param>
    /// <param name="model">model being trained</param>
    /// <param name="criterion">classification loss</param>
    /// <param name="optimizer">optimizer stepped once per batch</param>
    static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> train_data, RodNet model, Loss<Tensor, Tensor, Tensor> criterion, torch.optim.Optimizer optimizer)
    {
        model.train();

        double total_acc = 0.0;
        long total_count = 0;
        long log_interval = 1;

        var batch = 0;
        var batch_count = train_data.Count();

        using (var d = torch.NewDisposeScope())
        {
            foreach (var (y0z0t, uv, h, fish_label, success) in train_data)
            {
                optimizer.zero_grad();

                using (var predicted_labels = model.forward(fish_label, uv, y0z0t, h))
                {
                    var loss = criterion.forward(predicted_labels, success.to(ScalarType.Int64));
                    loss.backward();
                    // Clip gradients to stabilize SGD, as in the upstream example.
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1);
                    optimizer.step();

                    total_acc += (predicted_labels.argmax(1) == success).sum().to(torch.CPU).item<long>();
                    total_count += success.size(0);
                }

                batch += 1;
                if (batch % log_interval == 0)
                {
                    var accuracy = total_acc / total_count;
                    Console.WriteLine($"epoch: {epoch} | batch: {batch} / {batch_count} | accuracy: {accuracy:0.00}");
                }
            }
        }
    }

    /// <summary>
    /// Computes the model's accuracy over the batched test data.
    /// </summary>
    /// <param name="test_data">batches of (y0z0t, uv, h, fish_label, success) tensors</param>
    /// <param name="model">model to evaluate</param>
    /// <param name="criterion">kept for signature symmetry with train(); not needed for accuracy</param>
    /// <returns>fraction of samples whose argmax prediction equals the success label</returns>
    static double evaluate(IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> test_data, RodNet model, Loss<Tensor, Tensor, Tensor> criterion)
    {
        model.eval();

        double total_acc = 0.0;
        long total_count = 0;

        using (var d = torch.NewDisposeScope())
        {
            foreach (var (y0z0t, uv, h, fish_label, success) in test_data)
            {
                using (var predicted_labels = model.forward(fish_label, uv, y0z0t, h))
                {
                    // (removed an unused loss computation — only the argmax hit count matters here)
                    total_acc += (predicted_labels.argmax(1) == success).sum().to(torch.CPU).item<long>();
                    total_count += success.size(0);
                }
            }
            return total_acc / total_count;
        }
    }

    #endregion
}
internal class CSVReader : IDisposable
{
    /// <summary>
    /// Reads the fishing data csv and serves it as batched tensors on a given device.
    /// </summary>
    /// <param name="takeMask">
    /// Rows are grouped by the mask's length; the boolean at each in-group index decides whether
    /// that row is read. E.g. 8×true followed by 2×false puts roughly 80% into the training set.
    /// </param>
    /// <param name="path">csv file path</param>
    /// <param name="device">device the produced tensors are moved to</param>
    public CSVReader(IEnumerable<bool> takeMask, string path, Device device)
    {
        this.takeMask = takeMask.ToArray();
        _path = path;
        _device = device;
    }

    private readonly bool[] takeMask;
    private readonly string _path;
    private readonly Device _device;

    /// <summary>
    /// Enumerates the rows selected by takeMask, parsed into Data records.
    /// </summary>
    public IEnumerable<Data> Enumerate()
    {
        // Materialize once: the previous version re-enumerated File.ReadLines per row
        // (Count() plus Skip(i).First() made this O(n²) in file reads).
        var all = File.ReadLines(_path).Skip(1).ToList(); // 跳过首行列名 (skip the header row)
        int count = takeMask.Length;
        for (int index = 0; index < all.Count; index++)
        {
            // A row's position inside its group of `count` rows decides whether it is taken.
            if (takeMask[index % count])
            {
                yield return ParseLine(all[index]);
            }
        }
    }

    /// <summary>
    /// Groups the selected rows into batches of tensors; the last batch may be smaller.
    /// </summary>
    /// <param name="batch_size">maximum rows per batch</param>
    public IEnumerable<(Tensor, Tensor, Tensor, Tensor, Tensor)> GetBatches(long batch_size)
    {
        // This data set fits in memory, so we simply load it all and cache it between epochs.
        var inputs = new List<Data>();
        if (_data == null)
        {
            _data = new List<(Tensor, Tensor, Tensor, Tensor, Tensor)>();
            var counter = 0;
            var lines = Enumerate().ToList();
            var left = lines.Count;
            foreach (var line in lines)
            {
                inputs.Add(line);
                left -= 1;
                // Flush a full batch, or whatever remains at the end of the data.
                if (++counter == batch_size || left == 0)
                {
                    _data.Add(Batchifier(inputs));
                    inputs.Clear();
                    counter = 0;
                }
            }
        }
        return _data;
    }

    private List<(Tensor, Tensor, Tensor, Tensor, Tensor)> _data;
    private bool disposedValue;

    /// <summary>
    /// Converts one batch of parsed rows into the tensors RodNet.forward consumes.
    /// </summary>
    /// <param name="input">rows of one batch</param>
    /// <returns>y0z0t, uv, h, fish_label and success tensors</returns>
    private (Tensor, Tensor, Tensor, Tensor, Tensor) Batchifier(IEnumerable<Data> input)
    {
        var y0List = new List<double>();
        var z0List = new List<double>();
        var tList = new List<double>();
        var uList = new List<double>();
        var vList = new List<double>();
        var hList = new List<double>();
        var labelList = new List<int>();
        var successList = new List<int>();
        foreach (var line in input)
        {
            int fish_label = line.fish_label;
            int success = line.success;
            RodInput rodInput = new RodInput()
            {
                rod_x1 = line.rod_x1,
                rod_x2 = line.rod_x2,
                rod_y1 = line.rod_y1,
                rod_y2 = line.rod_y2,
                fish_x1 = line.fish_x1,
                fish_x2 = line.fish_x2,
                fish_y1 = line.fish_y1,
                fish_y2 = line.fish_y2
            };
            // Reuse the same pre-processing the production code applies before inference.
            var (y0, z0, t, u, v, h) = RodNet.GetRodStatePreProcess(rodInput);
            y0List.Add(y0);
            z0List.Add(z0);
            tList.Add(t);
            uList.Add(u);
            vList.Add(v);
            hList.Add(h);
            labelList.Add(fish_label);
            successList.Add(success);
        }
        Tensor y0Tensor = tensor(y0List, dtype: ScalarType.Float64).to(_device);
        Tensor z0Tensor = tensor(z0List, dtype: ScalarType.Float64).to(_device);
        Tensor tTensor = tensor(tList, dtype: ScalarType.Float64).to(_device);
        Tensor uTensor = tensor(uList, dtype: ScalarType.Float64).to(_device);
        Tensor vTensor = tensor(vList, dtype: ScalarType.Float64).to(_device);
        Tensor hTensor = tensor(hList, dtype: ScalarType.Float64).to(_device);
        Tensor fish_labelTensor = tensor(labelList, dtype: ScalarType.Int32).to(_device);
        Tensor successTensor = tensor(successList, dtype: ScalarType.Int32).to(_device);
        return (torch.stack([y0Tensor, z0Tensor, tTensor], dim: 1), torch.stack([uTensor, vTensor], dim: 1), hTensor.unsqueeze(1), fish_labelTensor, successTensor);
    }

    /// <summary>
    /// Parses one csv row. Column layout:
    /// time,bite_time,rod_x1,rod_x2,rod_y1,rod_y2,fish_x1,fish_x2,fish_y1,fish_y2,fish_label,success
    /// </summary>
    /// <param name="line">one csv data row</param>
    /// <returns>the parsed Data record</returns>
    public Data ParseLine(string line)
    {
        var columns = line.Split(','); // Split already returns an array; ToArray() was redundant
        // InvariantCulture: the csv always uses '.' decimals, regardless of the OS locale.
        return new Data(
            float.Parse(columns[2], CultureInfo.InvariantCulture), float.Parse(columns[3], CultureInfo.InvariantCulture),
            float.Parse(columns[4], CultureInfo.InvariantCulture), float.Parse(columns[5], CultureInfo.InvariantCulture),
            float.Parse(columns[6], CultureInfo.InvariantCulture), float.Parse(columns[7], CultureInfo.InvariantCulture),
            float.Parse(columns[8], CultureInfo.InvariantCulture), float.Parse(columns[9], CultureInfo.InvariantCulture),
            int.Parse(columns[10], CultureInfo.InvariantCulture), int.Parse(columns[11], CultureInfo.InvariantCulture));
    }

    protected virtual void Dispose(bool disposing)
    {
        if (!disposedValue)
        {
            if (disposing && _data != null)
            {
                // Release all cached native tensors.
                foreach (var (y0z0t, uv, h, label, success) in _data)
                {
                    y0z0t.Dispose();
                    uv.Dispose();
                    h.Dispose();
                    label.Dispose();
                    success.Dispose();
                }
            }
            disposedValue = true;
        }
    }

    public void Dispose()
    {
        // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}
/// <summary>
/// One parsed csv row: rod and fish coordinates (x1/x2/y1/y2, fed into RodNet.GetRodStatePreProcess
/// via RodInput), the fish class label, and the success value used as the classification target.
/// </summary>
internal record Data(float rod_x1, float rod_x2, float rod_y1, float rod_y2, float fish_x1, float fish_x2, float fish_y1, float fish_y2, int fish_label, int success);
}