using BetterGenshinImpact.Core.Config;
using BetterGenshinImpact.Core.Recognition.OCR;
using BetterGenshinImpact.Core.Simulator;
using BetterGenshinImpact.GameTask.AutoArtifactSalvage;
using BetterGenshinImpact.GameTask.Common.Job;
using BetterGenshinImpact.GameTask.Model.Area;
using BetterGenshinImpact.GameTask.Model.GameUI;
using BetterGenshinImpact.Helpers.Extensions;
using BetterGenshinImpact.View.Drawable;
using Fischless.WindowsInput;
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using static BetterGenshinImpact.GameTask.Common.TaskControl;

namespace BetterGenshinImpact.GameTask.GetGridIcons;

/// <summary>
/// Solo task that walks through the items of a grid screen and, for every item,
/// compares the prediction of the grid-icon ONNX model with the name read by OCR
/// and the star count read by <c>GetGridIconsTask.GetStars</c>, logging the running accuracy.
/// </summary>
public class GridIconsAccuracyTestTask : ISoloTask
{
    /// <summary>Number of leading feature components compared against each prototype vector.</summary>
    private const int FeatureDimension = 64;

    /// <summary>
    /// Squared distance below which the nearest prototype is accepted as a match.
    /// TODO: read the negative-sample distance (10) directly from the model.
    /// </summary>
    private const double MaxSquaredDistance = 10 * 10;

    /// <summary>Edge length, in pixels, of the square icon image expected by <see cref="Infer"/>.</summary>
    private const int IconSize = 125;

    private readonly ILogger logger = App.GetLogger<GridIconsAccuracyTestTask>();
    private readonly InputSimulator input = Simulation.SendInput;

    private CancellationToken ct;

    public string Name => "获取Grid界面物品图标独立任务";

    private readonly int? maxNumToTest;
    private readonly GridScreenName gridScreenName;

    /// <summary>
    /// Creates the task.
    /// </summary>
    /// <param name="gridScreenName">Grid screen whose items are checked.</param>
    /// <param name="maxNumToTest">Maximum number of items to check; <c>null</c> means no limit.</param>
    public GridIconsAccuracyTestTask(GridScreenName gridScreenName, int? maxNumToTest = null)
    {
        this.gridScreenName = gridScreenName;
        this.maxNumToTest = maxNumToTest;
    }

    /// <summary>
    /// Loads the icon recognition model together with its prototype vectors.
    /// </summary>
    /// <param name="prototypes">Prototype vectors read from <c>items.csv</c>, keyed by the first CSV column.</param>
    /// <returns>The inference session; the caller owns it and must dispose it.</returns>
    /// <exception cref="InvalidDataException">
    /// The model metadata lacks a usable <c>prefix_list</c> entry, or a CSV row is malformed.
    /// </exception>
    public static InferenceSession LoadModel(out Dictionary<string, float[]> prototypes)
    {
        var session = new InferenceSession(Global.Absolute(@"Assets\Model\Item\gridIcon.onnx"));
        try
        {
            if (!session.ModelMetadata.CustomMetadataMap.TryGetValue("prefix_list", out string? prefixListJson))
            {
                throw new InvalidDataException("模型文件缺少prefix_list");
            }

            // Prefixes are not predicted here; the list is only parsed to validate the model file.
            _ = System.Text.Json.JsonSerializer.Deserialize<List<string>>(prefixListJson)
                ?? throw new InvalidDataException("模型文件的prefix_list无法解析");

            prototypes = LoadPrototypes(Global.Absolute(@"Assets\Model\Item\items.csv"));
            return session;
        }
        catch
        {
            // Do not leak the native session when loading fails half-way.
            session.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Reads the prototype vectors from a CSV file whose first row is a header and whose
    /// rows hold a key in column 0 and a base64-encoded float array in column 1.
    /// </summary>
    private static Dictionary<string, float[]> LoadPrototypes(string csvPath)
    {
        var prototypes = new Dictionary<string, float[]>();

        // Skip the header row with the column names.
        foreach (string line in File.ReadLines(csvPath).Skip(1))
        {
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            string[] columns = line.Split(',');
            if (columns.Length < 2)
            {
                throw new InvalidDataException($"items.csv中存在格式错误的行:{line}");
            }

            prototypes.Add(columns[0], DecodeFloatVector(columns[1]));
        }

        return prototypes;
    }

    /// <summary>
    /// Decodes a base64 string into a float array by reinterpreting the raw bytes.
    /// </summary>
    private static float[] DecodeFloatVector(string base64)
    {
        byte[] bytes = Convert.FromBase64String(base64);
        float[] vector = new float[bytes.Length / sizeof(float)];

        // Throws ArgumentException when the byte count is not a multiple of sizeof(float).
        Buffer.BlockCopy(bytes, 0, vector, 0, bytes.Length);
        return vector;
    }

    /// <summary>
    /// Opens the configured grid screen where supported, then clicks through its items and
    /// logs, per item, whether the model prediction agrees with the OCR/CV reading.
    /// </summary>
    /// <param name="ct">Token used to cancel the task.</param>
    public async Task Start(CancellationToken ct)
    {
        this.ct = ct;

        switch (this.gridScreenName)
        {
            case GridScreenName.Weapons:
            case GridScreenName.Artifacts:
            case GridScreenName.CharacterDevelopmentItems:
            case GridScreenName.Food:
            case GridScreenName.Materials:
            case GridScreenName.Gadget:
            case GridScreenName.Quest:
            case GridScreenName.PreciousItems:
            case GridScreenName.Furnishings:
                await new ReturnMainUiTask().Start(ct);
                await AutoArtifactSalvageTask.OpenInventory(this.gridScreenName, this.input, this.logger, this.ct);
                break;
            default:
                logger.LogInformation("{name}暂不支持自动打开,请提前手动打开界面", gridScreenName.GetDescription());
                break;
        }

        using InferenceSession session = LoadModel(out Dictionary<string, float[]> prototypes);

        int remaining = this.maxNumToTest ?? int.MaxValue;
        int correctCount = 0;
        int totalCount = 0;

        GridScreen gridScreen = new GridScreen(GridParams.Templates[this.gridScreenName], this.logger, this.ct);
        gridScreen.OnAfterTurnToNewPage += GridScreen.DrawItemsAfterTurnToNewPage;
        gridScreen.OnBeforeScroll += () => VisionContext.Instance().DrawContent.ClearAll();

        try
        {
            await foreach ((ImageRegion pageRegion, Rect itemRect) in gridScreen)
            {
                using ImageRegion itemRegion = pageRegion.DeriveCrop(itemRect);
                itemRegion.Click();

                // Run the model while waiting for the UI to react to the click.
                Task uiDelay = Delay(300, ct);
                Task<(string?, int)> inferTask = Task.Run(() =>
                {
                    using Mat icon = itemRegion.SrcMat.GetGridIcon();
                    return Infer(icon, session, prototypes);
                }, ct);
                await Task.WhenAll(uiDelay, inferTask);
                (string? predName, int predStarNum) = await inferTask;

                // Reference values obtained with OCR / CV from the detail panel.
                // NOTE(review): vertical offsets are expressed as fractions of the width as well,
                // exactly as in the original code — presumably intentional; confirm.
                using var screen = CaptureToRectArea();
                using ImageRegion nameRegion = screen.DeriveCrop(new Rect(
                    (int)(screen.Width * 0.682),
                    (int)(screen.Width * 0.0625),
                    (int)(screen.Width * 0.256),
                    (int)(screen.Width * 0.03125)));
                var ocrResult = OcrFactory.Paddle.OcrResult(nameRegion.SrcMat);
                string itemName = ocrResult.Text;

                using ImageRegion starRegion = screen.DeriveCrop(new Rect(
                    (int)(screen.Width * 0.682),
                    (int)(screen.Width * 0.1823),
                    (int)(screen.Width * 0.105),
                    (int)(screen.Width * 0.02345)));
                int itemStarNum = GetGridIconsTask.GetStars(starRegion.SrcMat);

                // Statistics.
                totalCount++;
                bool isCorrect = predName != null && itemName.Contains(predName) && predStarNum == itemStarNum;
                if (isCorrect)
                {
                    correctCount++;
                }

                double accuracy = (double)correctCount / totalCount;
                if (predName == null)
                {
                    logger.LogInformation($"模型没有识别,应为:{itemName}|{itemStarNum}星,❌,正确率{accuracy:0.00}");
                }
                else if (isCorrect)
                {
                    logger.LogInformation($"{predName}|{predStarNum}星,✔,正确率{accuracy:0.00}");
                }
                else
                {
                    logger.LogInformation($"{predName}|{predStarNum}星,应为:{itemName}|{itemStarNum}星,❌,正确率{accuracy:0.00}");
                }

                remaining--;
                if (remaining <= 0)
                {
                    logger.LogInformation("检查次数已耗尽");
                    break;
                }
            }
        }
        finally
        {
            VisionContext.Instance().DrawContent.ClearAll();
        }
    }

    /// <summary>
    /// Runs the model on a single icon image. The caller must crop/scale the image to 125*125.
    /// </summary>
    /// <param name="mat">Icon image of size 125*125; converted from BGR to RGB before inference.</param>
    /// <param name="session">Inference session returned by <see cref="LoadModel"/>.</param>
    /// <param name="prototypes">Prototype vectors returned by <see cref="LoadModel"/>.</param>
    /// <returns>
    /// (predicted name, predicted star level). The name is the key of the nearest prototype, or
    /// <c>null</c> when the nearest prototype is not closer than the acceptance threshold; the
    /// star level is the index of the largest value in the third model output.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="mat"/> is not 125*125.</exception>
    public static (string?, int) Infer(Mat mat, InferenceSession session, Dictionary<string, float[]> prototypes)
    {
        if (mat.Size().Width != IconSize || mat.Size().Height != IconSize)
        {
            throw new ArgumentOutOfRangeException(nameof(mat), "输入图像尺寸应为125*125");
        }

        using Mat rgb = mat.CvtColor(ColorConversionCodes.BGR2RGB);
        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor("input_image", ToInputTensor(rgb))
        };

        using var results = session.Run(inputs);

        float[] feature = results[0].AsEnumerable<float>().ToArray();
        string? predName = FindNearestPrototype(feature, prototypes);

        // Index of the first maximum of the third output.
        float[] starScores = results[2].AsEnumerable<float>().ToArray();
        int predStar = Array.IndexOf(starScores, starScores.Max());

        return (predName, predStar);
    }

    /// <summary>
    /// Converts a 3-channel 8-bit image into a 1x3xHxW float tensor with values scaled to [0, 1].
    /// </summary>
    private static DenseTensor<float> ToInputTensor(Mat rgb)
    {
        // TODO: move to BgiOnnxFactory as a general Mat -> NamedOnnxValue helper?
        var tensor = new DenseTensor<float>(new[] { 1, 3, rgb.Height, rgb.Width });
        for (int y = 0; y < rgb.Height; y++)
        {
            for (int x = 0; x < rgb.Width; x++)
            {
                // Read the pixel once instead of once per channel.
                Vec3b pixel = rgb.At<Vec3b>(y, x);
                tensor[0, 0, y, x] = pixel[0] / 255f;
                tensor[0, 1, y, x] = pixel[1] / 255f;
                tensor[0, 2, y, x] = pixel[2] / 255f;
            }
        }

        return tensor;
    }

    /// <summary>
    /// Returns the key of the prototype with the smallest squared Euclidean distance to
    /// <paramref name="feature"/>, or <c>null</c> when that distance is not below the threshold.
    /// </summary>
    private static string? FindNearestPrototype(float[] feature, Dictionary<string, float[]> prototypes)
    {
        if (feature.Length < FeatureDimension)
        {
            throw new InvalidDataException($"模型输出的特征向量长度不足{FeatureDimension}");
        }

        string? nearestName = null;
        double? minSquaredDistance = null;

        foreach ((string name, float[] prototype) in prototypes)
        {
            if (prototype.Length < FeatureDimension)
            {
                throw new InvalidDataException($"原型向量“{name}”长度不足{FeatureDimension}");
            }

            double squaredDistance = 0;
            for (int i = 0; i < FeatureDimension; i++)
            {
                double diff = prototype[i] - feature[i];
                squaredDistance += diff * diff;
            }

            if (minSquaredDistance == null || squaredDistance < minSquaredDistance)
            {
                minSquaredDistance = squaredDistance;
                if (squaredDistance < MaxSquaredDistance)
                {
                    nearestName = name;
                }
            }
        }

        if (minSquaredDistance == null)
        {
            // No prototype was available to compare against.
            throw new InvalidOperationException("特征数据为空");
        }

        return nearestName;
    }
}