diff --git a/BetterGenshinImpact/Assets/Model/PaddleOCR/test_ocr.png b/BetterGenshinImpact/Assets/Model/PaddleOCR/test_ocr.png deleted file mode 100644 index 8a70bd7d..00000000 Binary files a/BetterGenshinImpact/Assets/Model/PaddleOCR/test_ocr.png and /dev/null differ diff --git a/BetterGenshinImpact/Assets/Model/PaddleOCR/test_pp_ocr.png b/BetterGenshinImpact/Assets/Model/PaddleOCR/test_pp_ocr.png new file mode 100644 index 00000000..066f6305 Binary files /dev/null and b/BetterGenshinImpact/Assets/Model/PaddleOCR/test_pp_ocr.png differ diff --git a/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs b/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs index dc82f453..7bf1c574 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs @@ -1,31 +1,72 @@ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Threading.Tasks; +using BetterGenshinImpact.GameTask; +using Microsoft.Extensions.Logging; namespace BetterGenshinImpact.Core.Recognition.OCR; public class OcrFactory { // public static IOcrService Media = Create(OcrEngineTypes.Media); - public static IOcrService Paddle { get; private set; } = Create(OcrEngineTypes.Paddle); + private static readonly ILogger Logger = App.GetLogger(); - public static IOcrService Create(OcrEngineTypes type, string? cultureInfoName = null) + public static IOcrService Paddle => _ocrServices.TryGetValue(OcrEngineTypes.Paddle, out var value) + ? value.Value + : CreateAndSet(OcrEngineTypes.Paddle, TaskContext.Instance().Config.OtherConfig.GameCultureInfoName).Value; + + /// + /// 保存着OcrEngineTypes和cultureInfoName与IOcrService + /// + private static readonly ConcurrentDictionary> + _ocrServices = new(); + + /// + /// 创建并设置 + /// + /// OcrEngineTypes + /// 文化名称 + /// cultureInfoName与IOcrService的pair + /// 如果不能创建 + private static KeyValuePair CreateAndSet(OcrEngineTypes type, string cultureInfoName) { - return type switch + var result = type switch { - OcrEngineTypes.Paddle => new PaddleOcrService(cultureInfoName), + OcrEngineTypes.Paddle => new KeyValuePair(cultureInfoName, + new PaddleOcrService(cultureInfoName)), _ => throw new ArgumentOutOfRangeException(nameof(type), type, null) }; + Logger.LogDebug("为 {CultureInfoName} 创建了类型为 {Type} 的 OCR服务", result.Key, result.Value); + _ocrServices[type] = result; + return result; } - public static async Task ChangeCulture(string cultrueInfoName) + private static string? GetCultureInfoName(OcrEngineTypes type) + { + return _ocrServices.TryGetValue(type, out KeyValuePair value) ? value.Key : null; + } + + public static async Task ChangeCulture(string cultureInfoName) { await Task.Run(() => { - lock (Paddle) + foreach (var ocrEngineTypes in Enum.GetValues()) { - Paddle = Create(OcrEngineTypes.Paddle, cultrueInfoName); - GC.Collect(); + try + { + // 避免重复创建OCR服务实例 + if (GetCultureInfoName(ocrEngineTypes) != cultureInfoName) + { + CreateAndSet(ocrEngineTypes, cultureInfoName); + } + } + catch (ArgumentOutOfRangeException) + { + } } + + GC.Collect(); }); } } \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrOperationImpl.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrOperationImpl.cs deleted file mode 100644 index 80217718..00000000 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrOperationImpl.cs +++ /dev/null @@ -1,41 +0,0 @@ -using OpenCvSharp; - -namespace BetterGenshinImpact.Core.Recognition.OCR.engine; - -/// -/// 实现PPOCR的自定义操作,代码翻自python -/// -public class OcrOperationImpl -{ - /// - /// 不支持 chw 之类的顺序 - /// https://github.com/PaddlePaddle/PaddleOCR/blob/0ee4094988c568077bba35ddb239030ced1ff270/ppocr/data/imaug/operators.py#L62 - /// - public static Mat NormalizeImageOperation(Mat data, - float? scale, // scale float32 - float[]? mean, //mean - float[]? std //std - ) - { - scale ??= 0.00392156862745f; - mean ??= [0.485f, 0.456f, 0.406f]; - std ??= [0.229f, 0.224f, 0.225f]; - var result = new Mat(); - data.ConvertTo(result, MatType.CV_32FC3, (double)scale); - Mat[] bgr = []; - try - { - bgr = result.Split(); - for (var i = 0; i < bgr.Length; ++i) - bgr[i].ConvertTo(bgr[i], MatType.CV_32FC1, 1 / std[i], (0.0 - mean[i]) / std[i]); - - Cv2.Merge(bgr, result); - } - finally - { - foreach (var channel in bgr) channel.Dispose(); - } - - return result; - } -} \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrUtils.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrUtils.cs index f8271b68..620c7190 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrUtils.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrUtils.cs @@ -2,14 +2,13 @@ using System; using System.Buffers; using System.Collections.Generic; using System.Linq; -using BetterGenshinImpact.Core.Recognition.OCR.engine; -using BetterGenshinImpact.Core.Recognition.OCR.paddle.data; +using BetterGenshinImpact.Core.Recognition.OCR.engine.data; using BetterGenshinImpact.Core.Recognition.OpenCv; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; using OpenCvSharp.Dnn; -namespace BetterGenshinImpact.Core.Recognition.OCR; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine; public static class OcrUtils { @@ -48,7 +47,7 @@ public static class OcrUtils } /// - /// 用于Det模型 + /// 用于Det模型 /// 归一化,标准化并返回Tensor。 ///
/// 归一化:固定范围归一化 @@ -63,51 +62,69 @@ public static class OcrUtils out IMemoryOwner tensorMemoryOwner, bool swapRb = false, bool crop = false, Size size = default) { + scale ??= 0.00392156862745f; + mean ??= [0.485f, 0.456f, 0.406f]; + std ??= [0.229f, 0.224f, 0.225f]; using var rt = new ResourcesTracker(); // 获取图像参数 var channels = src.Channels(); if (channels != 3) throw new ArgumentException($"图像通道数必须为3,当前为{channels}"); - var data = rt.T(OcrOperationImpl.NormalizeImageOperation(src, scale, mean, std)); + // var data = rt.T(OcrOperationImpl.NormalizeImageOperation(src, scale, mean, std)); + var stdMat = rt.NewMat(); + Mat[] bgr = []; + try + { + bgr = src.Split(); + for (var i = 0; i < bgr.Length; ++i) + bgr[i].ConvertTo(bgr[i], MatType.CV_32FC1, 1 / std[i], + (0.0 - mean[i]) / std[i] / (float)scale); + Cv2.Merge(bgr, stdMat); + } + finally + { + foreach (var channel in bgr) channel.Dispose(); + } + + //stdMat.GetArray(out var data); // 使用DNN模块创建blob var blob = rt.T(CvDnn.BlobFromImage( - data, - 1.0, + stdMat, + (double)scale, size, default, swapRb, crop )); + // 租用内存并复制数据 - var total = blob.Total(); - tensorMemoryOwner = MemoryPool.Shared.Rent((int)total); + var total = (int)blob.Total(); + tensorMemoryOwner = MemoryPool.Shared.Rent(total); blob.AsSpan().CopyTo(tensorMemoryOwner.Memory.Span); // 计算输出形状 return new DenseTensor( - tensorMemoryOwner.Memory[..(int)total], - new[] { 1, channels, data.Rows, data.Cols } + tensorMemoryOwner.Memory[..total], + new[] { 1, channels, stdMat.Rows, stdMat.Cols } ); } /// /// 不支持通道转换 - ///
- /// 用于PP-OCR的Rec模型,调整大小之后再归一化到-1~1,之后转换为Tensor + ///
+ /// 用于PP-OCR的Rec模型,调整大小之后再归一化到-1~1,之后转换为Tensor ///
- public static Tensor resize_norm_img(Mat img, OcrShape image_shape, + public static Tensor ResizeNormImg(Mat img, OcrShape imageShape, out IMemoryOwner tensorMemoryOwner, bool padding = true, InterpolationFlags interpolation = InterpolationFlags.Linear) { using var rt = new ResourcesTracker(); - var imgC = image_shape.Channel; - var imgH = image_shape.Height; - var imgW = image_shape.Width; + // var imgC = imageShape.Channel; + var imgH = imageShape.Height; + var imgW = imageShape.Width; var h = img.Height; var w = img.Width; - int resized_w; - var resizedImage = rt.NewMat(); if (!padding) { @@ -117,8 +134,8 @@ public static class OcrUtils else { var ratio = w / (double)h; - resized_w = Math.Ceiling(imgH * ratio) > imgW ? imgW : (int)Math.Ceiling(imgH * ratio); - Cv2.Resize(img, resizedImage, new Size(resized_w, imgH), 0, 0, interpolation); + var resizedW = Math.Ceiling(imgH * ratio) > imgW ? imgW : (int)Math.Ceiling(imgH * ratio); + Cv2.Resize(img, resizedImage, new Size(resizedW, imgH), 0, 0, interpolation); } /* @@ -126,7 +143,7 @@ public static class OcrUtils resized_image -= 0.5 resized_image /= 0.5 */ - // 归一化 + // 归一化到 +-1 // resizedImage.ConvertTo(resizedImage, MatType.CV_32F, 2 / 255f, 1); var blob = rt.T(CvDnn.BlobFromImage( resizedImage, @@ -166,15 +183,12 @@ public static class OcrUtils public static Mat Tensor2Mat(Tensor tensor) { var dimensions = tensor.Dimensions; - if (dimensions.Length !=4 || dimensions[0] != 1 || dimensions[1] != 1) - { + if (dimensions.Length != 4 || dimensions[0] != 1 || dimensions[1] != 1) throw new ArgumentException($"wrong tensor shape: {string.Join(",", dimensions.ToArray())}"); - } if (tensor is not DenseTensor denseTensor) return Mat.FromPixelData(dimensions[2], dimensions[3], MatType.CV_32FC1, tensor.ToArray()); var mat = new Mat(new Size(dimensions[3], dimensions[2]), MatType.CV_32FC1); denseTensor.Buffer.Span.CopyTo(mat.AsSpan()); return mat; - } } \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrVersionConfig.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrVersionConfig.cs index d0c6c58d..3e0f5d84 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrVersionConfig.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/OcrVersionConfig.cs @@ -1,9 +1,9 @@ -using BetterGenshinImpact.Core.Recognition.OCR.paddle.data; +using BetterGenshinImpact.Core.Recognition.OCR.engine.data; -namespace BetterGenshinImpact.Core.Recognition.OCR; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine; /// -/// ppocr的版本配置 +/// ppocr的版本配置 /// public readonly record struct OcrVersionConfig( string Name, diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrImgMode.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrImgMode.cs index 7d734627..f89f09b2 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrImgMode.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrImgMode.cs @@ -1,6 +1,7 @@ -namespace BetterGenshinImpact.Core.Recognition.OCR; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine.data; + /// -/// 图像的颜色顺序 +/// 图像的颜色顺序 /// public enum OcrImgMode { diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrMatOrder.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrMatOrder.cs index 28f5a899..908f5e4b 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrMatOrder.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrMatOrder.cs @@ -1,8 +1,9 @@ -namespace BetterGenshinImpact.Core.Recognition.OCR; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine.data; + /// -/// Mat的通道顺序 -/// hwc: height width channel -/// chw: channel height width +/// Mat的通道顺序 +/// hwc: height width channel +/// chw: channel height width /// public enum OcrMatOrder { diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrNormalizeImage.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrNormalizeImage.cs index b5077403..20992bd2 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrNormalizeImage.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrNormalizeImage.cs @@ -1,5 +1,6 @@ -namespace BetterGenshinImpact.Core.Recognition.OCR; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine.data; + /// -/// 标准归一化的三个参数 +/// 标准归一化的三个参数 /// public record OcrNormalizeImage(float Scale, float[] Mean, float[] Std); \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrShape.cs b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrShape.cs index 82bfe847..cbdf1c37 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrShape.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/engine/data/OcrShape.cs @@ -1,6 +1,6 @@ -namespace BetterGenshinImpact.Core.Recognition.OCR.paddle.data; +namespace BetterGenshinImpact.Core.Recognition.OCR.engine.data; /// -/// 图像形状表示 +/// 图像形状表示 /// public readonly record struct OcrShape(int Channel, int Width, int Height); \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/OCR/paddle/Det.cs b/BetterGenshinImpact/Core/Recognition/OCR/paddle/Det.cs index e41af092..2fb4010c 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/paddle/Det.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/paddle/Det.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using BetterGenshinImpact.Core.Recognition.OCR.engine; using BetterGenshinImpact.Core.Recognition.ONNX; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; @@ -15,7 +16,7 @@ public class Det public Det(BgiOnnxModel model, OcrVersionConfig config) { _config = config; - _session = BgiOnnxFactory.Instance.CreateInferenceSession(model,true); + _session = BgiOnnxFactory.Instance.CreateInferenceSession(model, true); } /// Gets or sets the maximum size for resizing the input image. @@ -38,7 +39,10 @@ public class Det ~Det() { - _session.Dispose(); + lock (_session) + { + _session.Dispose(); + } } public RotatedRect[] Run(Mat src) @@ -124,7 +128,7 @@ public class Det if (output.ValueType is not OnnxValueType.ONNX_TYPE_TENSOR) throw new Exception($"Unexpected output tensor value type: {output.ValueType}"); var outputTensor = output.AsTensor(); - return OcrUtils.Tensor2Mat(tensor: outputTensor); + return OcrUtils.Tensor2Mat(outputTensor); // 因为一个已知bug,tensor中内存在dml下使用完后会被释放掉,锁之外的代码会报错 } } diff --git a/BetterGenshinImpact/Core/Recognition/OCR/paddle/PaddleOcrService.cs b/BetterGenshinImpact/Core/Recognition/OCR/paddle/PaddleOcrService.cs index a208eb56..a55507f8 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/paddle/PaddleOcrService.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/paddle/PaddleOcrService.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.IO; using System.Linq; using BetterGenshinImpact.Core.Config; +using BetterGenshinImpact.Core.Recognition.OCR.engine; using BetterGenshinImpact.Core.Recognition.OCR.paddle; using BetterGenshinImpact.Core.Recognition.ONNX; using OpenCvSharp; @@ -17,29 +18,29 @@ public class PaddleOcrService : IOcrService /// 模型列表: /// https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.5/doc/doc_ch/models_list.md ///
- private readonly Det localDetModel; + private readonly Det _localDetModel; - private readonly Rec localRecModel; + private readonly Rec _localRecModel; - public PaddleOcrService(string? cultureInfoName = null) + public PaddleOcrService(string cultureInfoName) { var path = Global.Absolute(@"Assets\Model\PaddleOcr"); switch (cultureInfoName) { case "zh-Hant": - localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4); - localRecModel = new Rec(BgiOnnxModel.PaddleOcrChtRec, Path.Combine(path, "chinese_cht_dict.txt"), + _localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4); + _localRecModel = new Rec(BgiOnnxModel.PaddleOcrChtRec, Path.Combine(path, "chinese_cht_dict.txt"), OcrVersionConfig.PpOcrV3); break; case "fr": - localDetModel = new Det(BgiOnnxModel.PaddleOcrEnDet, OcrVersionConfig.PpOcrV3); - localRecModel = new Rec(BgiOnnxModel.PaddleOcrLatinRec, Path.Combine(path, "latin_dict.txt"), + _localDetModel = new Det(BgiOnnxModel.PaddleOcrEnDet, OcrVersionConfig.PpOcrV3); + _localRecModel = new Rec(BgiOnnxModel.PaddleOcrLatinRec, Path.Combine(path, "latin_dict.txt"), OcrVersionConfig.PpOcrV3); break; default: - localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4); - localRecModel = new Rec(BgiOnnxModel.PaddleOcrChRec, Path.Combine(path, "ppocr_keys_v1.txt"), + _localDetModel = new Det(BgiOnnxModel.PaddleOcrChDet, OcrVersionConfig.PpOcrV4); + _localRecModel = new Rec(BgiOnnxModel.PaddleOcrChRec, Path.Combine(path, "ppocr_keys_v1.txt"), OcrVersionConfig.PpOcrV4); break; @@ -73,7 +74,7 @@ public class PaddleOcrService : IOcrService /// public string OcrWithoutDetector(Mat mat) { - var str = localRecModel.Run(mat).Text; + var str = _localRecModel.Run(mat).Text; Debug.WriteLine($"PaddleOcrWithoutDetector 结果: {str}"); return str; } @@ -92,7 +93,7 @@ public class PaddleOcrService : IOcrService /// private OcrResult RunAll(Mat src, int recognizeBatchSize = 0) { - var rects = localDetModel.Run(src); + var rects = _localDetModel.Run(src); Mat[] mats = rects.Select(rect => { @@ -102,7 +103,7 @@ public class PaddleOcrService : IOcrService .ToArray(); try { - return new OcrResult(localRecModel.Run(mats, recognizeBatchSize) + return new OcrResult(_localRecModel.Run(mats, recognizeBatchSize) .Select((result, i) => new OcrResultRegion(rects[i], result.Text, result.Score)) .ToArray()); } diff --git a/BetterGenshinImpact/Core/Recognition/OCR/paddle/Rec.cs b/BetterGenshinImpact/Core/Recognition/OCR/paddle/Rec.cs index f5fe1cb2..edf546fc 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/paddle/Rec.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/paddle/Rec.cs @@ -5,7 +5,8 @@ using System.IO; using System.Linq; using System.Runtime.InteropServices; using System.Text; -using BetterGenshinImpact.Core.Recognition.OCR.paddle.data; +using BetterGenshinImpact.Core.Recognition.OCR.engine; +using BetterGenshinImpact.Core.Recognition.OCR.engine.data; using BetterGenshinImpact.Core.Recognition.ONNX; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; @@ -22,7 +23,7 @@ public class Rec public Rec(BgiOnnxModel model, string labelFilePath, OcrVersionConfig config) { _config = config; - _session = BgiOnnxFactory.Instance.CreateInferenceSession(model,true); + _session = BgiOnnxFactory.Instance.CreateInferenceSession(model, true); _labels = File.ReadAllLines(labelFilePath); @@ -95,7 +96,7 @@ public class Rec 3 => src, var x => throw new Exception($"Unexpect src channel: {x}, allow: (1/3/4)") }; - var result = OcrUtils.resize_norm_img(channel3, new OcrShape(3, maxWidth, modelHeight), + var result = OcrUtils.ResizeNormImg(channel3, new OcrShape(3, maxWidth, modelHeight), out var owner); lock (owners) { @@ -162,8 +163,10 @@ public class Rec score += (float)maxVal; sb.Append(OcrUtils.GetLabelByIndex(maxIdx[1], _labels)); } + lastIndex = maxIdx[1]; } + return new OcrRecognizerResult(sb.ToString(), score / sb.Length); }) .ToArray(); diff --git a/BetterGenshinImpact/Core/Recognition/ONNX/BgiOnnxFactory.cs b/BetterGenshinImpact/Core/Recognition/ONNX/BgiOnnxFactory.cs index b7303554..feded240 100644 --- a/BetterGenshinImpact/Core/Recognition/ONNX/BgiOnnxFactory.cs +++ b/BetterGenshinImpact/Core/Recognition/ONNX/BgiOnnxFactory.cs @@ -19,35 +19,20 @@ public class BgiOnnxFactory : Singleton { private static readonly ILogger Logger = App.GetLogger(); - - public ProviderType[] ProviderTypes { get; } - public int DmlDeviceId { get; } - public int CudaDeviceId { get; } - public bool OptimizedModel { get; } - public bool TrtUseEmbedMode { get; } - public bool EnableCache { get; } - public bool CpuOcr { get; } - /// - /// 缓存模型路径。如果一开始使用缓存就一直使用缓存文件,如果没有使用缓存就一直使用原始模型路径。 - ///
- /// 这样能避免并发加载模型问题。比如使用了未完全构建好的缓存文件,导致模型加载失败。 + /// 缓存模型路径。如果一开始使用缓存就一直使用缓存文件,如果没有使用缓存就一直使用原始模型路径。 + ///
+ /// 这样能避免并发加载模型问题。比如使用了未完全构建好的缓存文件,导致模型加载失败。 ///
- private ConcurrentDictionary _cachedModelPaths = new(); - + private readonly ConcurrentDictionary _cachedModelPaths = new(); public BgiOnnxFactory() { var config = TaskContext.Instance().Config.HardwareAccelerationConfig; - if (config.AutoAppendCudaPath) - { - AppendCudaPath(); - } + if (config.AutoAppendCudaPath) AppendCudaPath(); if (string.IsNullOrWhiteSpace(config.AdditionalPath)) - { AppendPath(config.AdditionalPath.Split(Path.PathSeparator)); - } ProviderTypes = GetProviderType(config.InferenceDevice, CudaDeviceId, DmlDeviceId); OptimizedModel = config.OptimizedModel; @@ -68,9 +53,17 @@ public class BgiOnnxFactory : Singleton CpuOcr); } + public ProviderType[] ProviderTypes { get; } + public int DmlDeviceId { get; } + public int CudaDeviceId { get; } + public bool OptimizedModel { get; } + public bool TrtUseEmbedMode { get; } + public bool EnableCache { get; } + public bool CpuOcr { get; } + /// - /// 根据InferenceDeviceType选择Provider + /// 根据InferenceDeviceType选择Provider /// /// InferenceDeviceType /// cuda设备id @@ -92,7 +85,6 @@ public class BgiOnnxFactory : Singleton SessionOptions? testSession = null; var hasGpu = false; if (!hasGpu && cudaDeviceId >= 0) - { // tensorrt本身包含cuda,设备id也是cuda的id,且比纯cuda效果好很多。 try { @@ -108,10 +100,8 @@ public class BgiOnnxFactory : Singleton { testSession?.Dispose(); } - } if (!hasGpu && dmlDeviceId >= 0) - { // dml效果不如tensorrt,但是比纯cuda稳定性强 try { @@ -128,10 +118,8 @@ public class BgiOnnxFactory : Singleton { testSession?.Dispose(); } - } if (!hasGpu && cudaDeviceId >= 0) - { // cuda优先级比较低,因为跑起来并不太理想。 try { @@ -147,12 +135,8 @@ public class BgiOnnxFactory : Singleton { testSession?.Dispose(); } - } - if (!hasGpu) - { - Logger.LogWarning("[init]GPU自动选择失败,回退到CPU处理"); - } + if (!hasGpu) Logger.LogWarning("[init]GPU自动选择失败,回退到CPU处理"); //无论如何都要加入cpu,一些计算在纯gpu上不被支持或性能很烂 list.Add(ProviderType.Cpu); @@ -163,7 +147,7 @@ public class BgiOnnxFactory : Singleton } /// - /// 自动嗅探并修改path以加载cuda + /// 自动嗅探并修改path以加载cuda /// private static void AppendCudaPath() { @@ -191,10 +175,7 @@ public class BgiOnnxFactory : Singleton { // 体系架构 var architecture = Enum.GetName(RuntimeInformation.ProcessArchitecture); - if (architecture is null) - { - return [s]; - } + if (architecture is null) return [s]; return [ @@ -219,15 +200,12 @@ public class BgiOnnxFactory : Singleton } /// - /// 将附加的path应用进来 + /// 将附加的path应用进来 /// /// 附加的path字符串 private static void AppendPath(string[] extraPath) { - if (extraPath.Length <= 0) - { - return; - } + if (extraPath.Length <= 0) return; var pathVariables = Environment.GetEnvironmentVariable("PATH", EnvironmentVariableTarget.Process) ?.Split(Path.PathSeparator).ToList() ?? new List(); @@ -244,17 +222,14 @@ public class BgiOnnxFactory : Singleton } /// - /// 根据模型创建一个YoloPredictor + /// 根据模型创建一个YoloPredictor /// /// 模型 /// BgiYoloPredictor public BgiYoloPredictor CreateYoloPredictor(BgiOnnxModel model) { Logger.LogDebug("[Yolo]创建yolo预测器,模型: {ModelName}", model.Name); - if (!EnableCache) - { - return new BgiYoloPredictor(model, model.ModalPath, CreateSessionOptions(model, false)); - } + if (!EnableCache) return new BgiYoloPredictor(model, model.ModalPath, CreateSessionOptions(model, false)); var cached = GetCached(model); return cached == null @@ -263,7 +238,7 @@ public class BgiOnnxFactory : Singleton } /// - /// 根据模型创建一个onnx运行时的InferenceSession + /// 根据模型创建一个onnx运行时的InferenceSession /// /// 模型 /// 是否是用于ocr的模型,默认false @@ -272,42 +247,33 @@ public class BgiOnnxFactory : Singleton { Logger.LogDebug("[ONNX]创建推理会话,模型: {ModelName}", model.Name); ProviderType[]? providerTypes = null; - if (CpuOcr && ocr) - { - providerTypes = [ProviderType.Cpu]; - } + if (CpuOcr && ocr) providerTypes = [ProviderType.Cpu]; if (!EnableCache) - { return new InferenceSession(model.ModalPath, CreateSessionOptions(model, false, providerTypes)); - } - var cached = GetCached(model); + var cached = GetCached(model, providerTypes); return cached == null ? new InferenceSession(model.ModalPath, CreateSessionOptions(model, true, providerTypes)) : new InferenceSession(cached, CreateSessionOptions(model, false, providerTypes)); } /// - /// 获取带有缓存的模型(目前只支持TensorRT) + /// 获取带有缓存的模型(目前只支持TensorRT) /// /// 模型 + /// 强制使用的 providerTypes /// 带有缓存的模型绝对路径,null表示尚未创建缓存 - private string? GetCached(BgiOnnxModel model) + private string? GetCached(BgiOnnxModel model, ProviderType[]? forcedProvider = null) { + var providerTypes = forcedProvider ?? ProviderTypes; // 目前只支持TensorRT - if (!ProviderTypes.Contains(ProviderType.TensorRt)) return null; + if (!providerTypes.Contains(ProviderType.TensorRt)) return null; var result = _cachedModelPaths.GetOrAdd(model, _GetCached); - if (result is null) - { - return result; - } + if (result is null) return result; // 判断文件是否存在 - if (File.Exists(result)) - { - return result; - } + if (File.Exists(result)) return result; Logger.LogWarning("[ONNX]模型 {Model} 的缓存文件可能已被删除,使用原始模型文件。", model.Name); return null; @@ -317,10 +283,8 @@ public class BgiOnnxFactory : Singleton { if (model.ModelRelativePath.StartsWith(BgiOnnxModel.ModelCacheRelativePath) && model.ModelRelativePath.EndsWith("_ctx.onnx")) - { // 这已经是带有缓存的文件路径了 return model.ModalPath; - } var ctxA = Path.Combine(model.CachePath, "trt", "_ctx.onnx"); if (File.Exists(ctxA)) @@ -343,8 +307,8 @@ public class BgiOnnxFactory : Singleton /// - /// 通过模型路径生成SessionOptions
- /// 如果加载的模型文件已经是带有缓存的模型,请将cacheFolder设为null避免重复生成。 + /// 通过模型路径生成SessionOptions
+ /// 如果加载的模型文件已经是带有缓存的模型,请将cacheFolder设为null避免重复生成。 ///
/// 模型路径 /// 是否生成缓存。有几种情况下不生成缓存:1为用户主动关闭,即enableCache为false。2为即将加载的模型文件已经是带有缓存的模型文件。 @@ -356,7 +320,6 @@ public class BgiOnnxFactory : Singleton var sessionOptions = new SessionOptions(); foreach (var type in forcedProvider is null || forcedProvider.Length == 0 ? ProviderTypes : forcedProvider) - { try { switch (type) @@ -395,23 +358,18 @@ public class BgiOnnxFactory : Singleton Logger.LogError("无法加载指定的 ONNX provider {Provider},跳过。请检查推理设备配置是否正确。({Err})", Enum.GetName(type), e.Message); } - } if (!OptimizedModel) return sessionOptions; if (!genCache) return sessionOptions; var optPath = Path.Combine(path.CachePath, "optimized"); - if (!Directory.Exists(optPath)) - { - Directory.CreateDirectory(optPath); - } - - sessionOptions.OptimizedModelFilePath = optPath; + if (!Directory.Exists(optPath)) Directory.CreateDirectory(optPath); + sessionOptions.OptimizedModelFilePath = Path.Combine(optPath, Path.GetFileName(path.ModalPath)); return sessionOptions; } /// - /// 获取TensorRT的配置 + /// 获取TensorRT的配置 /// /// 缓存生成的目录 /// trt配置 @@ -422,7 +380,7 @@ public class BgiOnnxFactory : Singleton // 不使用缓存目录 var r = new Dictionary { - ["device_id"] = CudaDeviceId.ToString(), + ["device_id"] = CudaDeviceId.ToString() }; return r; } @@ -438,7 +396,7 @@ public class BgiOnnxFactory : Singleton ["trt_timing_cache_path"] = Global.Absolute(Path.Combine(BgiOnnxModel.ModelCacheRelativePath, "trt_timing")), // ["trt_force_timing_cache"] = "1", - ["device_id"] = CudaDeviceId.ToString(), + ["device_id"] = CudaDeviceId.ToString() }; if (TrtUseEmbedMode) { @@ -451,22 +409,20 @@ public class BgiOnnxFactory : Singleton } if (!Directory.Exists(result["trt_ep_context_file_path"])) - { Directory.CreateDirectory(result["trt_ep_context_file_path"]); - } return result; } /// - /// 获取cuda provider的配置 + /// 获取cuda provider的配置 /// /// cuda配置 private Dictionary GetCudaProviderConfig() { var result = new Dictionary { - ["device_id"] = CudaDeviceId.ToString(), + ["device_id"] = CudaDeviceId.ToString() }; return result; } diff --git a/BetterGenshinImpact/Core/Recognition/ONNX/BgiYoloPredictor.cs b/BetterGenshinImpact/Core/Recognition/ONNX/BgiYoloPredictor.cs index d8db8188..c192b76e 100644 --- a/BetterGenshinImpact/Core/Recognition/ONNX/BgiYoloPredictor.cs +++ b/BetterGenshinImpact/Core/Recognition/ONNX/BgiYoloPredictor.cs @@ -77,4 +77,9 @@ public class BgiYoloPredictor : IDisposable Predictor.Dispose(); } } -} + + ~BgiYoloPredictor() + { + Dispose(); + } +} \ No newline at end of file diff --git a/BetterGenshinImpact/Core/Recognition/ONNX/SVTR/PickTextInference.cs b/BetterGenshinImpact/Core/Recognition/ONNX/SVTR/PickTextInference.cs index 38fcb3df..5e6f0ee2 100644 --- a/BetterGenshinImpact/Core/Recognition/ONNX/SVTR/PickTextInference.cs +++ b/BetterGenshinImpact/Core/Recognition/ONNX/SVTR/PickTextInference.cs @@ -9,7 +9,7 @@ using System.Diagnostics; using System.IO; using System.Text; using System.Text.Json; -using BetterGenshinImpact.Core.Recognition.OCR; +using BetterGenshinImpact.Core.Recognition.OCR.engine; namespace BetterGenshinImpact.Core.Recognition.ONNX.SVTR; diff --git a/BetterGenshinImpact/ViewModel/MainWindowViewModel.cs b/BetterGenshinImpact/ViewModel/MainWindowViewModel.cs index 865cad01..b2b4067e 100644 --- a/BetterGenshinImpact/ViewModel/MainWindowViewModel.cs +++ b/BetterGenshinImpact/ViewModel/MainWindowViewModel.cs @@ -204,7 +204,8 @@ public partial class MainWindowViewModel : ObservableObject, IViewModel // 低版本才需要迁移 if (fileVersionInfo.FileVersion != null && !Global.IsNewVersion(fileVersionInfo.FileVersion)) { - var res = await MessageBox.ShowAsync("检测到旧的 BetterGI 配置,是否迁移配置并清理旧目录?", "BetterGI", System.Windows.MessageBoxButton.YesNo, MessageBoxImage.Question); + var res = await MessageBox.ShowAsync("检测到旧的 BetterGI 配置,是否迁移配置并清理旧目录?", "BetterGI", + System.Windows.MessageBoxButton.YesNo, MessageBoxImage.Question); if (res == System.Windows.MessageBoxResult.Yes) { // 迁移配置,拷贝整个目录并覆盖 @@ -224,13 +225,14 @@ public partial class MainWindowViewModel : ObservableObject, IViewModel */ private void Patch2() { - List files =[ + List files = + [ Global.Absolute(@"Assets\Map\mainMap256Block_SIFT.kp"), Global.Absolute(@"Assets\Map\mainMap256Block_SIFT.mat"), Global.Absolute(@"Assets\Map\mainMap2048Block_SIFT.kp"), Global.Absolute(@"Assets\Map\mainMap2048Block_SIFT.mat"), ]; - + // 循环删除 foreach (var file in files.Where(File.Exists)) { @@ -246,19 +248,24 @@ public partial class MainWindowViewModel : ObservableObject, IViewModel { try { - string gameCultureInfoName = TaskContext.Instance().Config.OtherConfig.GameCultureInfoName; - await OcrFactory.ChangeCulture(gameCultureInfoName); - var s = OcrFactory.Paddle.Ocr(new Mat(Global.Absolute(@"Assets\Model\PaddleOCR\test_ocr.png"))); + // 现在OCR创建的时候会自己读设置了 + // string gameCultureInfoName = TaskContext.Instance().Config.OtherConfig.GameCultureInfoName; + // await OcrFactory.ChangeCulture(gameCultureInfoName); + var s = OcrFactory.Paddle.Ocr(new Mat(Global.Absolute(@"Assets\Model\PaddleOCR\test_pp_ocr.png"))); Debug.WriteLine("PaddleOcr预热结果:" + s); } catch (Exception e) { Console.WriteLine(e); - _logger.LogError("PaddleOcr预热异常,解决方案:【https://bettergi.com/faq.html】\r\n" + e.Source + "\r\n--" + Environment.NewLine + e.StackTrace + "\r\n---" + Environment.NewLine + e.Message); + _logger.LogError("PaddleOcr预热异常,解决方案:【https://bettergi.com/faq.html】\r\n" + e.Source + "\r\n--" + + Environment.NewLine + e.StackTrace + "\r\n---" + Environment.NewLine + e.Message); var innerException = e.InnerException; if (innerException != null) { - _logger.LogError("PaddleOcr预热内部异常,解决方案:【https://bettergi.com/faq.html】\r\n" + innerException.Source + "\r\n--" + Environment.NewLine + innerException.StackTrace + "\r\n---" + Environment.NewLine + innerException.Message); + _logger.LogError("PaddleOcr预热内部异常,解决方案:【https://bettergi.com/faq.html】\r\n" + + innerException.Source + "\r\n--" + Environment.NewLine + + innerException.StackTrace + "\r\n---" + Environment.NewLine + + innerException.Message); throw innerException; } else @@ -270,8 +277,12 @@ public partial class MainWindowViewModel : ObservableObject, IViewModel } catch (Exception e) { - MessageBox.Warning("PaddleOcr预热失败,解决方案:【https://bettergi.com/faq.html】 \r\n" + e.Source + "\r\n--" + Environment.NewLine + e.StackTrace + "\r\n---" + Environment.NewLine + e.Message); - Process.Start(new ProcessStartInfo("https://bettergi.com/faq.html#%E2%9D%93%E6%8F%90%E7%A4%BA-paddleocr%E9%A2%84%E7%83%AD%E5%A4%B1%E8%B4%A5-%E5%BA%94%E8%AF%A5%E5%A6%82%E4%BD%95%E8%A7%A3%E5%86%B3") { UseShellExecute = true }); + MessageBox.Warning("PaddleOcr预热失败,解决方案:【https://bettergi.com/faq.html】 \r\n" + e.Source + "\r\n--" + + Environment.NewLine + e.StackTrace + "\r\n---" + Environment.NewLine + e.Message); + Process.Start( + new ProcessStartInfo( + "https://bettergi.com/faq.html#%E2%9D%93%E6%8F%90%E7%A4%BA-paddleocr%E9%A2%84%E7%83%AD%E5%A4%B1%E8%B4%A5-%E5%BA%94%E8%AF%A5%E5%A6%82%E4%BD%95%E8%A7%A3%E5%86%B3") + { UseShellExecute = true }); } } @@ -290,7 +301,8 @@ public partial class MainWindowViewModel : ObservableObject, IViewModel } catch (Exception e) { - _logger.LogDebug("获取设备ID异常:" + e.Source + "\r\n--" + Environment.NewLine + e.StackTrace + "\r\n---" + Environment.NewLine + e.Message); + _logger.LogDebug("获取设备ID异常:" + e.Source + "\r\n--" + Environment.NewLine + e.StackTrace + "\r\n---" + + Environment.NewLine + e.Message); } // 每个设备只运行一次