ONNX test case passed

This commit is contained in:
huiyadanli
2023-10-06 18:17:49 +08:00
parent 4bad84970a
commit 2b11eb3a91
7 changed files with 233 additions and 6 deletions

View File

@@ -24,6 +24,8 @@
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Hosting" Version="7.0.1" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.16.0" />
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.16.0" />
<PackageReference Include="OpenCvSharp4.Windows" Version="4.8.0.20230708" />
<PackageReference Include="Microsoft.Windows.CsWin32" Version="0.3.46-beta">
<PrivateAssets>all</PrivateAssets>

View File

@@ -0,0 +1,200 @@
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using BetterGenshinImpact.Core.Config;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
using BetterGenshinImpact.Core.Recognition.OpenCv;
using System.Text.Json;
using Range = OpenCvSharp.Range;
using Size = OpenCvSharp.Size;
using System.Drawing.Imaging;
namespace BetterGenshinImpact.Core.Recognition.ONNX.SVTR;
public class SvtrModelRunner
{
private readonly InferenceSession _session;
private readonly Dictionary<int, string> _wordDictionary;
public SvtrModelRunner()
{
var options = new SessionOptions();
_session = new InferenceSession(Global.Absolute("Config\\Model\\Yap\\model_training.onnx"), options);
// 获取模型的输入节点名称
var inputName = _session.InputMetadata.Keys.First();
Debug.WriteLine($"Input Name:{inputName}");
var json = File.ReadAllText(Global.Absolute("Config\\Model\\Yap\\index_2_word.json"));
_wordDictionary = JsonSerializer.Deserialize<Dictionary<int, string>>(json);
if (_wordDictionary == null)
{
throw new Exception("index_2_word.json deserialize failed");
}
}
public static Tensor<float> ToOnnxTensorUnsafe(Mat padded)
{
// Create the Tensor with the appropiate dimensions for the NN
Tensor<float> data = new DenseTensor<float>(new[] { 1, 1, padded.Height, padded.Width });
// Rows = height, Cols = width
var imageData = new float[padded.Height, padded.Width];
for (int y = 0; y < padded.Rows; y++)
{
for (int x = 0; x < padded.Cols; x++)
{
byte b = padded.At<byte>(y, x);
data[0, 0, y, x] = b / 255f;
imageData[y, x] = b;
}
}
Debug.WriteLine(imageData);
return data;
}
public static Tensor<float> ToOnnxTensorUnsafe2(Mat padded)
{
var channels = padded.Channels();
var nRows = padded.Rows;
var nCols = padded.Cols * channels;
if (padded.IsContinuous())
{
nCols *= nRows;
nRows = 1;
}
var inputData = new float[nCols];
unsafe
{
for (var i = 0; i < nRows; i++)
{
var p = padded.Ptr(i);
var b = (byte*)p.ToPointer();
for (var j = 0; j < nCols; j++)
{
inputData[j] = b[j] / 255f;
}
}
}
return new DenseTensor<float>(new Memory<float>(inputData), new int[] { 1, 1, 32, 384 });
;
}
public string RunInference(Mat padded)
{
// 将输入数据调整为 (1, 1, 32, 384) 形状的张量
//var reshapedInputData = new DenseTensor<float>(new Memory<float>(inputData), new int[] { 1, 1, 32, 384 });
var reshapedInputData = ToOnnxTensorUnsafe2(padded);
var y = reshapedInputData.Dimensions[2];
var x = reshapedInputData.Dimensions[3];
// 创建输入 NamedOnnxValue
var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("input", reshapedInputData) };
// 运行模型推理
using var results = _session.Run(inputs);
// 获取输出数据
var resultsArray = results.ToArray();
Tensor<float> boxes = resultsArray[0].AsTensor<float>();
var ans = "";
var lastWord = "";
for (var i = 0; i < boxes.Dimensions[0]; i++)
{
var maxIndex = 0;
var maxValue = -1.0;
for (var j = 0; j < _wordDictionary.Count; j++)
{
var value = boxes[i, 0, j];
if (value > maxValue)
{
maxValue = value;
maxIndex = j;
}
}
var word = _wordDictionary[maxIndex];
if (word != lastWord && word != "|")
{
ans = ans + word;
}
lastWord = word;
}
Debug.WriteLine("ans:" + ans);
return ans;
}
public string RunInferenceMore(Mat mat)
{
Debug.Assert(mat.Depth() == MatType.CV_8UC1);
//Cv2.ImShow("mat1", mat);
mat = ResizeHelper.ResizeTo(mat, 221, 32);
Cv2.ImWrite(Global.Absolute("resized.bmp"), mat);
var padded = new Mat(new Size(384, 32), MatType.CV_8UC1, Scalar.Black);
padded[new Rect(0, 0, mat.Width, mat.Height)] = mat;
Cv2.ImWrite(Global.Absolute("padded.bmp"), padded);
/*var channels = padded.Channels();
var nRows = padded.Rows;
var nCols = padded.Cols * channels;
if (padded.IsContinuous())
{
nCols *= nRows;
nRows = 1;
}
var inputData = new float[nCols];
unsafe
{
for (var i = 0; i < nRows; i++)
{
var p = padded.Ptr(i);
var b = (byte*)p.ToPointer();
for (var j = 0; j < nCols; j++)
{
inputData[j] = b[j] * 0.1f / 255;
}
}
}*/
//var imageData = new float[384, 32];
// Rows = height, Cols = width
//for (int y = 0; y < padded.Rows; y++)
//{
// for (int x = 0; x < padded.Cols; x++)
// {
// imageData[x, y] = padded.At<byte>(x, y) * 0.1f / 255;
// }
//}
//var inputData = new float[384 * 32]; // 你的输入数据
//for (int y = 0; y < padded.Rows; y++)
//{
// for (int x = 0; x < padded.Cols; x++)
// {
// inputData[y * 384 + x] = imageData[x, y];
// }
//}
return RunInference(padded);
}
}

View File

@@ -35,5 +35,16 @@ namespace BetterGenshinImpact.Core.Recognition.OpenCv
}
return src;
}
public static Mat ResizeTo(Mat src, int width, int height)
{
if (src.Width != width || src.Height != height)
{
var dst = new Mat();
Cv2.Resize(src, dst, new Size(width, height));
return dst;
}
return src;
}
}
}

View File

@@ -55,6 +55,11 @@ namespace BetterGenshinImpact.GameTask.Model
// 注意截图区域要和游戏窗口实际区域一致
// todo 窗口移动后?
GameScreenSize = SystemControl.GetGameScreenRect(hWnd);
if (GameScreenSize.Width < 800)
{
throw new ArgumentException("游戏窗口分辨率过低,请确认当前原神窗口是否处于最小化状态!");
}
AssetScale = GameScreenSize.Width / 1920d;
GameWindowRect = SystemControl.GetWindowRect(hWnd);
CaptureAreaRect = GameWindowRect;

View File

@@ -43,6 +43,6 @@
</StackPanel>
<ui:Button x:Name="StartCaptureTest" Margin="0,20,0,0" Content="测试图像捕获" Command="{Binding StartCaptureTestCommand}" />
<ui:Button x:Name="SvtrTest" Margin="0,20,0,0" Content="SVTR测试" Command="{Binding SvtrTestCommand}" />
</StackPanel>
</Page>

View File

@@ -20,6 +20,8 @@ using CommunityToolkit.Mvvm.Messaging;
using Wpf.Ui.Controls;
using BetterGenshinImpact.Service.Interface;
using BetterGenshinImpact.Core.Config;
using BetterGenshinImpact.Core.Recognition.ONNX.SVTR;
using OpenCvSharp;
namespace BetterGenshinImpact.ViewModel.Pages;
@@ -65,11 +67,11 @@ public partial class HomePageViewModel : ObservableObject, INavigationAware
{
Debug.WriteLine("HomePageViewModel Loaded");
#if DEBUG
var hWnd = SystemControl.FindGenshinImpactHandle();
if (hWnd != IntPtr.Zero)
{
OnStartTrigger();
}
//var hWnd = SystemControl.FindGenshinImpactHandle();
//if (hWnd != IntPtr.Zero)
//{
// OnStartTrigger();
//}
#endif
}
@@ -143,4 +145,11 @@ public partial class HomePageViewModel : ObservableObject, INavigationAware
public void OnNavigatedFrom()
{
}
[RelayCommand]
private void OnSvtrTest()
{
var mat = new Mat(Global.Absolute("Config\\Model\\Yap\\0_2_「甜甜花」的种子_bin_38x227.jpg"), ImreadModes.Grayscale);
new SvtrModelRunner().RunInferenceMore(mat);
}
}