mirror of
https://github.com/babalae/better-genshin-impact.git
synced 2026-05-10 00:44:10 +08:00
ONNX test case passed
This commit is contained in:
@@ -24,6 +24,8 @@
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="7.0.0" />
|
||||
<PackageReference Include="Microsoft.Extensions.Hosting" Version="7.0.1" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
|
||||
<PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.16.0" />
|
||||
<PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.16.0" />
|
||||
<PackageReference Include="OpenCvSharp4.Windows" Version="4.8.0.20230708" />
|
||||
<PackageReference Include="Microsoft.Windows.CsWin32" Version="0.3.46-beta">
|
||||
<PrivateAssets>all</PrivateAssets>
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
using Microsoft.ML.OnnxRuntime;
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Drawing;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using BetterGenshinImpact.Core.Config;
|
||||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||||
using OpenCvSharp;
|
||||
using BetterGenshinImpact.Core.Recognition.OpenCv;
|
||||
using System.Text.Json;
|
||||
using Range = OpenCvSharp.Range;
|
||||
using Size = OpenCvSharp.Size;
|
||||
using System.Drawing.Imaging;
|
||||
|
||||
namespace BetterGenshinImpact.Core.Recognition.ONNX.SVTR;
|
||||
|
||||
public class SvtrModelRunner
|
||||
{
|
||||
private readonly InferenceSession _session;
|
||||
private readonly Dictionary<int, string> _wordDictionary;
|
||||
|
||||
public SvtrModelRunner()
|
||||
{
|
||||
var options = new SessionOptions();
|
||||
_session = new InferenceSession(Global.Absolute("Config\\Model\\Yap\\model_training.onnx"), options);
|
||||
// 获取模型的输入节点名称
|
||||
var inputName = _session.InputMetadata.Keys.First();
|
||||
Debug.WriteLine($"Input Name:{inputName}");
|
||||
|
||||
var json = File.ReadAllText(Global.Absolute("Config\\Model\\Yap\\index_2_word.json"));
|
||||
_wordDictionary = JsonSerializer.Deserialize<Dictionary<int, string>>(json);
|
||||
if (_wordDictionary == null)
|
||||
{
|
||||
throw new Exception("index_2_word.json deserialize failed");
|
||||
}
|
||||
}
|
||||
|
||||
public static Tensor<float> ToOnnxTensorUnsafe(Mat padded)
|
||||
{
|
||||
// Create the Tensor with the appropiate dimensions for the NN
|
||||
Tensor<float> data = new DenseTensor<float>(new[] { 1, 1, padded.Height, padded.Width });
|
||||
|
||||
|
||||
// Rows = height, Cols = width
|
||||
var imageData = new float[padded.Height, padded.Width];
|
||||
for (int y = 0; y < padded.Rows; y++)
|
||||
{
|
||||
for (int x = 0; x < padded.Cols; x++)
|
||||
{
|
||||
byte b = padded.At<byte>(y, x);
|
||||
data[0, 0, y, x] = b / 255f;
|
||||
imageData[y, x] = b;
|
||||
}
|
||||
}
|
||||
|
||||
Debug.WriteLine(imageData);
|
||||
return data;
|
||||
}
|
||||
|
||||
public static Tensor<float> ToOnnxTensorUnsafe2(Mat padded)
|
||||
{
|
||||
var channels = padded.Channels();
|
||||
var nRows = padded.Rows;
|
||||
var nCols = padded.Cols * channels;
|
||||
if (padded.IsContinuous())
|
||||
{
|
||||
nCols *= nRows;
|
||||
nRows = 1;
|
||||
}
|
||||
|
||||
var inputData = new float[nCols];
|
||||
unsafe
|
||||
{
|
||||
for (var i = 0; i < nRows; i++)
|
||||
{
|
||||
var p = padded.Ptr(i);
|
||||
var b = (byte*)p.ToPointer();
|
||||
for (var j = 0; j < nCols; j++)
|
||||
{
|
||||
inputData[j] = b[j] / 255f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new DenseTensor<float>(new Memory<float>(inputData), new int[] { 1, 1, 32, 384 });
|
||||
;
|
||||
}
|
||||
|
||||
public string RunInference(Mat padded)
|
||||
{
|
||||
// 将输入数据调整为 (1, 1, 32, 384) 形状的张量
|
||||
//var reshapedInputData = new DenseTensor<float>(new Memory<float>(inputData), new int[] { 1, 1, 32, 384 });
|
||||
var reshapedInputData = ToOnnxTensorUnsafe2(padded);
|
||||
|
||||
var y = reshapedInputData.Dimensions[2];
|
||||
var x = reshapedInputData.Dimensions[3];
|
||||
|
||||
// 创建输入 NamedOnnxValue
|
||||
var inputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("input", reshapedInputData) };
|
||||
|
||||
// 运行模型推理
|
||||
using var results = _session.Run(inputs);
|
||||
|
||||
// 获取输出数据
|
||||
var resultsArray = results.ToArray();
|
||||
Tensor<float> boxes = resultsArray[0].AsTensor<float>();
|
||||
|
||||
var ans = "";
|
||||
var lastWord = "";
|
||||
for (var i = 0; i < boxes.Dimensions[0]; i++)
|
||||
{
|
||||
var maxIndex = 0;
|
||||
var maxValue = -1.0;
|
||||
for (var j = 0; j < _wordDictionary.Count; j++)
|
||||
{
|
||||
var value = boxes[i, 0, j];
|
||||
if (value > maxValue)
|
||||
{
|
||||
maxValue = value;
|
||||
maxIndex = j;
|
||||
}
|
||||
}
|
||||
|
||||
var word = _wordDictionary[maxIndex];
|
||||
if (word != lastWord && word != "|")
|
||||
{
|
||||
ans = ans + word;
|
||||
}
|
||||
|
||||
lastWord = word;
|
||||
}
|
||||
|
||||
Debug.WriteLine("ans:" + ans);
|
||||
return ans;
|
||||
}
|
||||
|
||||
public string RunInferenceMore(Mat mat)
|
||||
{
|
||||
Debug.Assert(mat.Depth() == MatType.CV_8UC1);
|
||||
//Cv2.ImShow("mat1", mat);
|
||||
mat = ResizeHelper.ResizeTo(mat, 221, 32);
|
||||
Cv2.ImWrite(Global.Absolute("resized.bmp"), mat);
|
||||
|
||||
var padded = new Mat(new Size(384, 32), MatType.CV_8UC1, Scalar.Black);
|
||||
|
||||
|
||||
padded[new Rect(0, 0, mat.Width, mat.Height)] = mat;
|
||||
Cv2.ImWrite(Global.Absolute("padded.bmp"), padded);
|
||||
|
||||
/*var channels = padded.Channels();
|
||||
var nRows = padded.Rows;
|
||||
var nCols = padded.Cols * channels;
|
||||
if (padded.IsContinuous())
|
||||
{
|
||||
nCols *= nRows;
|
||||
nRows = 1;
|
||||
}
|
||||
|
||||
var inputData = new float[nCols];
|
||||
unsafe
|
||||
{
|
||||
for (var i = 0; i < nRows; i++)
|
||||
{
|
||||
var p = padded.Ptr(i);
|
||||
var b = (byte*)p.ToPointer();
|
||||
for (var j = 0; j < nCols; j++)
|
||||
{
|
||||
inputData[j] = b[j] * 0.1f / 255;
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
|
||||
//var imageData = new float[384, 32];
|
||||
// Rows = height, Cols = width
|
||||
//for (int y = 0; y < padded.Rows; y++)
|
||||
//{
|
||||
// for (int x = 0; x < padded.Cols; x++)
|
||||
// {
|
||||
// imageData[x, y] = padded.At<byte>(x, y) * 0.1f / 255;
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
//var inputData = new float[384 * 32]; // 你的输入数据
|
||||
//for (int y = 0; y < padded.Rows; y++)
|
||||
//{
|
||||
// for (int x = 0; x < padded.Cols; x++)
|
||||
// {
|
||||
// inputData[y * 384 + x] = imageData[x, y];
|
||||
// }
|
||||
//}
|
||||
|
||||
return RunInference(padded);
|
||||
}
|
||||
}
|
||||
@@ -35,5 +35,16 @@ namespace BetterGenshinImpact.Core.Recognition.OpenCv
|
||||
}
|
||||
return src;
|
||||
}
|
||||
|
||||
public static Mat ResizeTo(Mat src, int width, int height)
|
||||
{
|
||||
if (src.Width != width || src.Height != height)
|
||||
{
|
||||
var dst = new Mat();
|
||||
Cv2.Resize(src, dst, new Size(width, height));
|
||||
return dst;
|
||||
}
|
||||
return src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,11 @@ namespace BetterGenshinImpact.GameTask.Model
|
||||
// 注意截图区域要和游戏窗口实际区域一致
|
||||
// todo 窗口移动后?
|
||||
GameScreenSize = SystemControl.GetGameScreenRect(hWnd);
|
||||
if (GameScreenSize.Width < 800)
|
||||
{
|
||||
throw new ArgumentException("游戏窗口分辨率过低,请确认当前原神窗口是否处于最小化状态!");
|
||||
}
|
||||
|
||||
AssetScale = GameScreenSize.Width / 1920d;
|
||||
GameWindowRect = SystemControl.GetWindowRect(hWnd);
|
||||
CaptureAreaRect = GameWindowRect;
|
||||
|
||||
@@ -43,6 +43,6 @@
|
||||
</StackPanel>
|
||||
|
||||
<ui:Button x:Name="StartCaptureTest" Margin="0,20,0,0" Content="测试图像捕获" Command="{Binding StartCaptureTestCommand}" />
|
||||
|
||||
<ui:Button x:Name="SvtrTest" Margin="0,20,0,0" Content="SVTR测试" Command="{Binding SvtrTestCommand}" />
|
||||
</StackPanel>
|
||||
</Page>
|
||||
@@ -20,6 +20,8 @@ using CommunityToolkit.Mvvm.Messaging;
|
||||
using Wpf.Ui.Controls;
|
||||
using BetterGenshinImpact.Service.Interface;
|
||||
using BetterGenshinImpact.Core.Config;
|
||||
using BetterGenshinImpact.Core.Recognition.ONNX.SVTR;
|
||||
using OpenCvSharp;
|
||||
|
||||
namespace BetterGenshinImpact.ViewModel.Pages;
|
||||
|
||||
@@ -65,11 +67,11 @@ public partial class HomePageViewModel : ObservableObject, INavigationAware
|
||||
{
|
||||
Debug.WriteLine("HomePageViewModel Loaded");
|
||||
#if DEBUG
|
||||
var hWnd = SystemControl.FindGenshinImpactHandle();
|
||||
if (hWnd != IntPtr.Zero)
|
||||
{
|
||||
OnStartTrigger();
|
||||
}
|
||||
//var hWnd = SystemControl.FindGenshinImpactHandle();
|
||||
//if (hWnd != IntPtr.Zero)
|
||||
//{
|
||||
// OnStartTrigger();
|
||||
//}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -143,4 +145,11 @@ public partial class HomePageViewModel : ObservableObject, INavigationAware
|
||||
public void OnNavigatedFrom()
|
||||
{
|
||||
}
|
||||
|
||||
[RelayCommand]
|
||||
private void OnSvtrTest()
|
||||
{
|
||||
var mat = new Mat(Global.Absolute("Config\\Model\\Yap\\0_2_「甜甜花」的种子_bin_38x227.jpg"), ImreadModes.Grayscale);
|
||||
new SvtrModelRunner().RunInferenceMore(mat);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user