From e9d11f726700c24dc80b5e52e3b3e9d2ee9c2f31 Mon Sep 17 00:00:00 2001 From: Takaranoao Date: Thu, 19 Feb 2026 23:08:46 -0800 Subject: [PATCH] =?UTF-8?q?=E6=96=87=E6=9C=AC=E8=AF=86=E5=88=AB=E7=9A=84?= =?UTF-8?q?=E6=A8=A1=E7=B3=8A=E5=8C=B9=E9=85=8D=E5=8A=9F=E8=83=BD=20(#2799?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: add AGENTS.md to .gitignore * feat(config): 新增 AllowDuplicateChar OCR配置项 * refactor(ocr): Rec 暴露protected成员、提取RunInference、支持AllowDuplicateChar * feat(ocr): 打通 AllowDuplicateChar 参数链 PaddleOcrService → Rec * feat(ocr): OcrUtils 新增 CreateLabelDict/CreateWeights 工具方法 * feat(helpers): 新增 LruCache 缓存工具类 * feat(ocr): 新增 RecMatch DP模糊匹配识别器 * test(helpers): 新增 LruCache 单元测试 * test(ocr): 新增 RecMatch.GetTarget / CreateLabelDict 单元测试 * fix(ocr): 修复 RecMatch 中权重矩阵乘法的使用方式 * refactor(ocr): 合并 RecMatch 到 Rec,提取可测试静态方法,补充单元测试 将 RecMatch 子类合并到 Rec 中,消除继承关系和重复的批处理逻辑(提取 RunBatch)。 将 GetTarget 核心逻辑和 GetMaxScoreDP 提取为 OcrUtils 静态方法以便独立测试。 重命名测试文件并新增 16 个单元测试覆盖 MapStringToLabelIndices、GetMaxScoreDP、CreateWeights。 * feat(ocr): 将 Rec.RunMatch 暴露给 JS 引擎和内部 C# 代码 新增 IOcrMatchService 接口,提供基于 DP 模糊匹配的 OcrMatch/OcrMatchDirect 方法, 返回 0~1 置信度分数。PaddleOcrService 实现该接口,OcrFactory.PaddleMatch 保证 非 null 返回(引擎不支持时自动回退到普通 OCR + 编辑距离字符串比较)。 BvPage 新增 OcrMatch/WaitForOcrMatch 供 JS 脚本使用,阈值可通过配置调整。 * feat(ui): 为 OCR 配置添加允许重复字符和模糊匹配阈值的设置项 在通用设置页 OCR 配置区域新增两个控件: - 允许连续重复字符(AllowDuplicateChar)开关 - OCR模糊匹配阈值(OcrMatchDefaultThreshold)输入框 * fix: 修复 PR #2799 代码审查中发现的多项问题 - 修复 Rec.cs 空文本时 score/sb.Length 除零产生 NaN - 修复 BvPage.cs rect==default 时同一对象被双重 Dispose - 移除 Rec.cs Finalizer 避免 GC 线程加锁死锁 - 移除 CacheHelper WeakKey 无效功能,简化为直接 Dictionary 查找 - 添加 weights 数组长度与模型输出维度校验 - 修复 CreateLabelDict 空格标签索引冲突 - 修复 GetMaxScoreDP availableCount=0 除零 - 修复 OcrMatchFallbackService Contains 大小写敏感 - 修复 BvPage.cs DefaultRetryInterval=0 除零 - 添加 OcrMatchDefaultThreshold [0,1] 范围约束 - 提取 PaddleOcrService BGRA→BGR 转换辅助方法 - 使用 Interlocked.CompareExchange 
修复 OcrFactory Fallback 线程安全 - 增大 LruCacheTests BuilderTest TTL 裕量避免 CI 不稳定 - 更新 .gitignore 注释 * fix: 修复 OcrMatch 归一化分母导致多区域匹配分数过低的 bug,改进 UI - 修复 GetMaxScoreFlat 中 availableCount 使用非空图像数作为分母, 导致多文字区域场景下匹配分数被过度稀释的问题,改为使用 target.Length - AllowDuplicateChar 设置项添加"需重新加载OCR引擎"的提示 - OCR模糊匹配阈值控件从 TextBox 改为 Slider + 数值显示 - 移除 Det 类中有问题的 finalizer(含锁的析构函数可能导致死锁) - 补充多区域场景的单元测试 * feat(ocr): 添加队伍切换时使用OcrMatch模糊匹配的选项和相关配置 * fix(ui): 更新匹配成功阈值默认值为 0.8 * fix(ocr): 修复队伍切换逻辑中的空值处理和优化代码结构 * refactor: 简化 LruCache,移除弱引用支持和 Builder 模式 - 移除有 TOCTOU bug 的 WeakReference 支持(且无实际使用方) - CacheItem 类改为 ValueTuple 减少堆分配 - 无过期时不再赋值 DateTime.MaxValue,过期检查短路跳过 - 移除仅剩两参数的 LruCacheBuilder,直接使用构造函数 * fix(ocr): 修复 CreateWeights 中空格字符权重写入错误索引的 bug 复用 CreateLabelDict 构建索引映射,确保空格映射到 labels.Count+1, 与 CreateLabelDict 保持一致。添加对应测试用例。 * fix(ocr): 修复 GCHandle.Alloc 失败时 finally 中 Free 掩盖原始异常的问题 * fix(ocr): 添加队伍选择按钮存在性检查,避免 PartySetupFailedException * fix(ocr): 调整 OcrMatchDefaultThreshold 的 TickFrequency 为 0.01 * fix(ocr): 修复区域裁剪逻辑,确保裁剪尺寸不为负值 * fix(ocr): 优化字符置信度提取逻辑,直接按目标字符索引查找置信度 * fix(ocr): 修正变量命名以保持一致性,调整方法名大小写 * fix(ocr): 修改 CreateWeights 方法以使用标签字典和标签计数,优化权重创建逻辑 * fix(ocr): 更新 OCR 置信度阈值设置,确保阈值范围为 0.01 到 0.99,并优化相关逻辑 --- .gitignore | 2 +- BetterGenshinImpact/Core/BgiVision/BvPage.cs | 56 ++- .../Core/Config/OtherConfig.cs | 40 ++ .../Core/Recognition/OCR/Engine/OcrUtils.cs | 123 +++++- .../Core/Recognition/OCR/IOcrMatchService.cs | 26 ++ .../Core/Recognition/OCR/OcrFactory.cs | 53 ++- .../OCR/OcrMatchFallbackService.cs | 99 +++++ .../Core/Recognition/OCR/Paddle/Det.cs | 10 +- .../OCR/Paddle/PaddleOcrService.cs | 75 +++- .../Core/Recognition/OCR/Paddle/Rec.cs | 338 +++++++++++------ .../GameTask/Common/Job/SwitchPartyTask.cs | 352 +++++++++++------- BetterGenshinImpact/Helpers/CacheHelper.cs | 103 +++++ .../View/Pages/CommonSettingsPage.xaml | 161 +++++++- .../OCRTests/OcrMatchFallbackServiceTests.cs | 195 ++++++++++ .../OCRTests/OcrUtilsTests.cs | 291 +++++++++++++++ .../HelperTests/LruCacheTests.cs | 74 
++++ 16 files changed, 1702 insertions(+), 296 deletions(-) create mode 100644 BetterGenshinImpact/Core/Recognition/OCR/IOcrMatchService.cs create mode 100644 BetterGenshinImpact/Core/Recognition/OCR/OcrMatchFallbackService.cs create mode 100644 BetterGenshinImpact/Helpers/CacheHelper.cs create mode 100644 Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrMatchFallbackServiceTests.cs create mode 100644 Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrUtilsTests.cs create mode 100644 Test/BetterGenshinImpact.UnitTest/HelperTests/LruCacheTests.cs diff --git a/.gitignore b/.gitignore index 3f896d07..97605ab9 100644 --- a/.gitignore +++ b/.gitignore @@ -28,7 +28,7 @@ github_actions_cache/ *.zip -# Rider +# IDE & AI tools .idea .trae .claude diff --git a/BetterGenshinImpact/Core/BgiVision/BvPage.cs b/BetterGenshinImpact/Core/BgiVision/BvPage.cs index 950de111..0c6ac8de 100644 --- a/BetterGenshinImpact/Core/BgiVision/BvPage.cs +++ b/BetterGenshinImpact/Core/BgiVision/BvPage.cs @@ -2,7 +2,9 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using BetterGenshinImpact.Core.Recognition; +using BetterGenshinImpact.Core.Recognition.OCR; using BetterGenshinImpact.Core.Simulator; +using BetterGenshinImpact.GameTask; using BetterGenshinImpact.GameTask.Common; using BetterGenshinImpact.GameTask.Model.Area; using Fischless.WindowsInput; @@ -118,4 +120,56 @@ public class BvPage { GameCaptureRegion.GameRegion1080PPosClick(x, y); } -} \ No newline at end of file + + /// + /// 使用模糊匹配判断截图中是否包含目标文字。 + /// 通过 自动选择最佳实现(DP 模糊匹配或普通 OCR + 字符串比较)。 + /// + /// 目标字符串 + /// 感兴趣区域,default 表示全屏 + /// 匹配阈值 (0~1),null 使用配置中的默认阈值 + /// 是否匹配成功 + public bool OcrMatch(string target, Rect rect = default, double? threshold = null) + { + var matchService = OcrFactory.PaddleMatch; + var actualThreshold = threshold + ?? 
TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold; + + var screen = TaskControl.CaptureToRectArea(); + try + { + var roi = rect == default ? null : screen.DeriveCrop(rect); + try + { + var score = matchService.OcrMatch((roi ?? screen).SrcMat, target); + return score >= actualThreshold; + } + finally + { + roi?.Dispose(); + } + } + finally + { + screen.Dispose(); + } + } + + /// + /// 重复截图并使用模糊匹配,等待目标文字出现。 + /// 超时返回 false 而非抛异常。 + /// + /// 目标字符串 + /// 感兴趣区域,default 表示全屏 + /// 匹配阈值 (0~1),null 使用配置中的默认阈值 + /// 超时时间(毫秒),null 使用 DefaultTimeout + /// 是否在超时前匹配成功 + public async Task WaitForOcrMatch(string target, Rect rect = default, double? threshold = null, int? timeout = null) + { + var actualTimeout = timeout ?? DefaultTimeout; + var retryCount = DefaultRetryInterval > 0 ? actualTimeout / DefaultRetryInterval : 1; + + return await NewRetry.WaitForAction(() => OcrMatch(target, rect, threshold), + _cancellationToken, retryCount, DefaultRetryInterval); + } +} diff --git a/BetterGenshinImpact/Core/Config/OtherConfig.cs b/BetterGenshinImpact/Core/Config/OtherConfig.cs index ac65d2b7..fccba007 100644 --- a/BetterGenshinImpact/Core/Config/OtherConfig.cs +++ b/BetterGenshinImpact/Core/Config/OtherConfig.cs @@ -106,6 +106,46 @@ public partial class OtherConfig : ObservableObject /// [ObservableProperty] private PaddleOcrModelConfig _paddleOcrModelConfig = PaddleOcrModelConfig.V4Auto; + + /// + /// 允许OCR结果中出现连续重复字符(关闭CTC重复字符折叠) + /// + [ObservableProperty] + private bool _allowDuplicateChar; + + /// + /// 切换队伍时使用 OcrMatch 模糊匹配代替正则表达式匹配 + /// + [ObservableProperty] + private bool _useOcrMatchForPartySwitch = true; + + /// + /// OcrMatch 模糊匹配的默认阈值 (0~1),分数 ≥ 阈值视为匹配成功 + /// + [ObservableProperty] + private double _ocrMatchDefaultThreshold = 0.8; + + partial void OnOcrMatchDefaultThresholdChanged(double value) + { + if (value is <= 0 or > 1) + { + OcrMatchDefaultThreshold = Math.Clamp(value, 0.01, 1); + } + } + + /// + /// PaddleOCR 识别置信度阈值 
(0~1),低于此阈值的字符将被过滤 + /// + [ObservableProperty] + private double _paddleOcrThreshold = 0.5; + + partial void OnPaddleOcrThresholdChanged(double value) + { + if (value is < 0 or >= 1) + { + PaddleOcrThreshold = Math.Clamp(value, 0, 0.99); + } + } } //public partial class OtherConfig : ObservableObject diff --git a/BetterGenshinImpact/Core/Recognition/OCR/Engine/OcrUtils.cs b/BetterGenshinImpact/Core/Recognition/OCR/Engine/OcrUtils.cs index f4be7343..d82aec52 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/Engine/OcrUtils.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/Engine/OcrUtils.cs @@ -16,9 +16,9 @@ public static class OcrUtils /// 预处理速度比unsafe快5倍以上,且吃的资源还少 /// /// 输入图像,若不是灰度图会转换 - /// tensor的Memory,用完需要释放 + /// tensor的Memory,用完需要释放 /// - public static Tensor ToTensorYapDnn(Mat inputImage, out IMemoryOwner tensorMemoryOwnser) + public static Tensor ToTensorYapDnn(Mat inputImage, out IMemoryOwner tensorMemoryOwner) { using var rt = new ResourcesTracker(); Mat dst; @@ -40,10 +40,10 @@ public static class OcrUtils // 使用向量运算代替循环 var blob = rt.T(CvDnn.BlobFromImage(padded, 1.0 / 255.0, default, default, false, false)); var nCols = padded.Cols * padded.Rows; - tensorMemoryOwnser = MemoryPool.Shared.Rent(nCols); + tensorMemoryOwner = MemoryPool.Shared.Rent(nCols); // 内存复制,如果直接传指针构建的话速度还不如多复制一份 - blob.AsSpan().CopyTo(tensorMemoryOwnser.Memory.Span); - return new DenseTensor(tensorMemoryOwnser.Memory[..nCols], [1, 1, 32, 384]); + blob.AsSpan().CopyTo(tensorMemoryOwner.Memory.Span); + return new DenseTensor(tensorMemoryOwner.Memory[..nCols], [1, 1, 32, 384]); } /// @@ -180,6 +180,119 @@ public static class OcrUtils }; } + /// + /// 从标签列表构建字符串→索引字典,供 Rec 模糊匹配使用。 + /// 索引从1开始(0为CTC空白符),空格字符为 labels.Count+1。 + /// + /// 识别模型的标签列表 + /// 各标签的字符长度集合(降序排列,用于从长到短贪心匹配) + public static IReadOnlyDictionary CreateLabelDict( + IReadOnlyList labels, out int[] labelLengths) + { + var dict = new Dictionary(); + var lengths = new HashSet(); + for (var i = 0; i < labels.Count; 
i++) + { + if (labels[i] == " ") continue; + var len = labels[i].Length; + if (len > 0) lengths.Add(len); + dict[labels[i]] = i + 1; + } + // 空格字符对应索引 labels.Count + 1 + dict[" "] = labels.Count + 1; + lengths.Add(1); + // 降序:先尝试更长的标签 + labelLengths = lengths.OrderByDescending(x => x).ToArray(); + return dict; + } + + /// + /// 根据额外权重字典,创建与标签列表等长的权重数组(用于加权推理分数)。 + /// 未指定权重的标签默认为 1.0。 + /// + public static float[] CreateWeights( + Dictionary extraWeights, IReadOnlyDictionary labelDict, int labelCount) + { + var result = new float[labelCount + 2]; + Array.Fill(result, 1.0f); + foreach (var (key, value) in extraWeights) + { + if (!labelDict.TryGetValue(key, out var index)) continue; + if (index >= 0 && index < result.Length) + { + result[index] = value; + } + } + return result; + } + + /// + /// 将目标字符串映射为标签索引序列。 + /// 使用贪心从长到短匹配,无法映射的字符会被跳过。 + /// + /// 目标字符串 + /// 标签→索引字典(由 CreateLabelDict 生成) + /// 标签长度集合,降序排列(由 CreateLabelDict 生成) + public static int[] MapStringToLabelIndices( + string target, + IReadOnlyDictionary labelDict, + int[] labelLengths) + { + var chars = target.ToCharArray(); + var targetIndices = new int[chars.Length]; + Array.Fill(targetIndices, -1); + var index = 0; + while (index < chars.Length) + { + var found = false; + foreach (var labelLength in labelLengths) + { + if (index + labelLength > chars.Length) continue; + var subStr = new string(chars, index, labelLength); + if (!labelDict.TryGetValue(subStr, out var labelIndex)) continue; + targetIndices[index] = labelIndex; + index += labelLength; + found = true; + break; + } + if (!found) index++; + } + + return targetIndices.Where(x => x != -1).ToArray(); + } + + /// + /// 动态规划最大子序列匹配。 + /// 在 result 序列中找到 target 的最大置信度子序列匹配,返回归一化分数 (0~1)。 + /// + /// OCR 输出的 (labelIndex, confidence) 序列 + /// 目标标签索引序列 + /// 归一化分母(通常为 target.Length,得到每个目标字符的平均置信度) + public static double GetMaxScoreDp((int, float)[] result, int[] target, int availableCount) + { + if (target.Length == 0 || availableCount <= 0) return 
0; + + var dp = new double[target.Length + 1]; + dp[0] = 0; + for (var j = 1; j <= target.Length; j++) + dp[j] = -255d; // 不可达 + + foreach (var (index, confidence) in result) + { + // 逆序更新,避免同一 result 元素被多次使用 + for (var j = target.Length; j >= 1; j--) + { + if (index != target[j - 1]) continue; + if (!(dp[j - 1] > -200)) continue; // 前序不可达 + var newSum = dp[j - 1] + confidence; + if (newSum > dp[j]) dp[j] = newSum; + } + } + + if (dp[target.Length] <= -200) return 0; // 无法完整匹配 + return dp[target.Length] / availableCount; + } + public static Mat Tensor2Mat(Tensor tensor) { var dimensions = tensor.Dimensions; diff --git a/BetterGenshinImpact/Core/Recognition/OCR/IOcrMatchService.cs b/BetterGenshinImpact/Core/Recognition/OCR/IOcrMatchService.cs new file mode 100644 index 00000000..a76f2fca --- /dev/null +++ b/BetterGenshinImpact/Core/Recognition/OCR/IOcrMatchService.cs @@ -0,0 +1,26 @@ +using OpenCvSharp; + +namespace BetterGenshinImpact.Core.Recognition.OCR; + +/// +/// 基于 DP 模糊匹配的 OCR 服务接口,返回匹配置信度分数 (0~1)。 +/// 独立于 IOcrService,仅由支持模糊匹配的引擎实现。 +/// +public interface IOcrMatchService +{ + /// + /// 使用检测器定位文字区域后,对每个区域进行模糊匹配,返回最高置信度 (0~1)。 + /// + /// 输入图像(推荐三通道 BGR) + /// 目标字符串 + /// 匹配置信度,0 表示完全不匹配,1 表示完全匹配 + double OcrMatch(Mat mat, string target); + + /// + /// 不使用检测器,直接对整张图像进行模糊匹配,返回置信度 (0~1)。 + /// + /// 输入图像(推荐三通道 BGR) + /// 目标字符串 + /// 匹配置信度,0 表示完全不匹配,1 表示完全匹配 + double OcrMatchDirect(Mat mat, string target); +} diff --git a/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs b/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs index 6ff3cdaf..7e6b49d7 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/OcrFactory.cs @@ -1,5 +1,6 @@ using System; using System.Globalization; +using System.Threading; using System.Threading.Tasks; using BetterGenshinImpact.Core.Config; using BetterGenshinImpact.Core.Recognition.OCR.Paddle; @@ -18,7 +19,27 @@ public class OcrFactory : IDisposable public static 
IOcrService Paddle => App.ServiceProvider.GetRequiredService().PaddleOcr; private IOcrService PaddleOcr => _paddleOcrService ??= Create(OcrEngineTypes.Paddle); + /// + /// 获取支持模糊匹配的 OCR 服务。 + /// 若引擎原生支持 IOcrMatchService 则直接返回,否则回退到普通 OCR + 字符串相似度。 + /// 访问此属性会触发 Paddle 引擎的懒加载。 + /// + public static IOcrMatchService PaddleMatch + { + get + { + var factory = App.ServiceProvider.GetRequiredService(); + var service = factory.PaddleOcr; + if (service is IOcrMatchService matchService) + return matchService; + var fallback = new OcrMatchFallbackService(service); + return Interlocked.CompareExchange(ref factory._paddleOcrMatchFallback, fallback, null) + ?? fallback; + } + } + private IOcrService? _paddleOcrService; + private IOcrMatchService? _paddleOcrMatchFallback; private readonly ILogger _logger; private readonly OtherConfig.Ocr _config; @@ -87,34 +108,33 @@ public class OcrFactory : IDisposable private PaddleOcrService CreatePaddleOcrInstance() { + var allowDuplicateChar = _config.AllowDuplicateChar; + var threshold = (float)_config.PaddleOcrThreshold; + var factory = App.ServiceProvider.GetRequiredService(); return _config.PaddleOcrModelConfig switch { PaddleOcrModelConfig.V4Auto => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.FromCultureInfoV4(GetCultureInfo()) ?? - PaddleOcrService.PaddleOcrModelType.V4), + PaddleOcrService.PaddleOcrModelType.V4, + allowDuplicateChar, threshold), PaddleOcrModelConfig.V5Auto => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.FromCultureInfo(GetCultureInfo()) ?? 
- PaddleOcrService.PaddleOcrModelType.V5), + PaddleOcrService.PaddleOcrModelType.V5, + allowDuplicateChar, threshold), PaddleOcrModelConfig.V5 => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V5), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5, allowDuplicateChar, threshold), PaddleOcrModelConfig.V4 => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V4), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V4, allowDuplicateChar, threshold), PaddleOcrModelConfig.V4En => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V4En), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V4En, allowDuplicateChar, threshold), PaddleOcrModelConfig.V5Korean => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V5Korean), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Korean, allowDuplicateChar, threshold), PaddleOcrModelConfig.V5Latin => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V5Latin), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Latin, allowDuplicateChar, threshold), PaddleOcrModelConfig.V5Eslav => - new PaddleOcrService(App.ServiceProvider.GetRequiredService(), - PaddleOcrService.PaddleOcrModelType.V5Eslav), + new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Eslav, allowDuplicateChar, threshold), _ => throw new ArgumentOutOfRangeException(nameof(_config.PaddleOcrModelConfig), _config.PaddleOcrModelConfig, "不支持的 Paddle OCR 模型配置") }; @@ -123,6 +143,7 @@ public class OcrFactory : IDisposable public Task Unload() { + _paddleOcrMatchFallback = null; if (_paddleOcrService is not IDisposable disposable) { _paddleOcrService = null; diff --git 
a/BetterGenshinImpact/Core/Recognition/OCR/OcrMatchFallbackService.cs b/BetterGenshinImpact/Core/Recognition/OCR/OcrMatchFallbackService.cs new file mode 100644 index 00000000..9d38d39d --- /dev/null +++ b/BetterGenshinImpact/Core/Recognition/OCR/OcrMatchFallbackService.cs @@ -0,0 +1,99 @@ +using System; +using System.Diagnostics; +using OpenCvSharp; + +namespace BetterGenshinImpact.Core.Recognition.OCR; + +/// +/// 当 OCR 引擎不支持 IOcrMatchService 时的回退实现。 +/// 使用普通 OCR 识别文字后,通过字符串相似度进行匹配。 +/// +public class OcrMatchFallbackService : IOcrMatchService +{ + private readonly IOcrService _ocrService; + + public OcrMatchFallbackService(IOcrService ocrService) + { + _ocrService = ocrService; + } + + public double OcrMatch(Mat mat, string target) + { + var startTime = Stopwatch.GetTimestamp(); + var ocrResult = _ocrService.OcrResult(mat); + var score = ComputeBestTextSimilarity(ocrResult, target); + var time = Stopwatch.GetElapsedTime(startTime); + Debug.WriteLine($"OcrMatchFallback 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}"); + return score; + } + + public double OcrMatchDirect(Mat mat, string target) + { + var startTime = Stopwatch.GetTimestamp(); + var text = _ocrService.OcrWithoutDetector(mat); + var score = ComputeTextSimilarity(text, target); + var time = Stopwatch.GetElapsedTime(startTime); + Debug.WriteLine($"OcrMatchDirectFallback 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}"); + return score; + } + + /// + /// 在 OCR 结果的所有区域中找到与目标字符串最相似的分数。 + /// + private static double ComputeBestTextSimilarity(OcrResult ocrResult, string target) + { + double bestScore = 0; + foreach (var region in ocrResult.Regions) + { + var score = ComputeTextSimilarity(region.Text, target); + if (score > bestScore) bestScore = score; + if (score >= 1.0) break; + } + + return bestScore; + } + + /// + /// 计算两个字符串的相似度 (0~1)。 + /// 优先检查子串包含关系,否则使用编辑距离计算。 + /// + public static double ComputeTextSimilarity(string text, string target) + { + if (string.IsNullOrEmpty(target)) 
return 1.0; + if (string.IsNullOrEmpty(text)) return 0.0; + if (text.Contains(target, StringComparison.OrdinalIgnoreCase)) return 1.0; + if (target.Contains(text, StringComparison.OrdinalIgnoreCase)) return (double)text.Length / target.Length; + + var distance = LevenshteinDistance(text, target); + var maxLen = Math.Max(text.Length, target.Length); + return 1.0 - (double)distance / maxLen; + } + + /// + /// 计算两个字符串之间的编辑距离(Levenshtein Distance)。 + /// + public static int LevenshteinDistance(string s, string t) + { + var sLen = s.Length; + var tLen = t.Length; + var prev = new int[tLen + 1]; + var curr = new int[tLen + 1]; + + for (var j = 0; j <= tLen; j++) + prev[j] = j; + + for (var i = 1; i <= sLen; i++) + { + curr[0] = i; + for (var j = 1; j <= tLen; j++) + { + var cost = s[i - 1] == t[j - 1] ? 0 : 1; + curr[j] = Math.Min(Math.Min(curr[j - 1] + 1, prev[j] + 1), prev[j - 1] + cost); + } + + (prev, curr) = (curr, prev); + } + + return prev[tLen]; + } +} diff --git a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Det.cs b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Det.cs index 6c2c25db..b49a4ca2 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Det.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Det.cs @@ -30,15 +30,7 @@ public class Det(BgiOnnxModel model, OcrVersionConfig config, BgiOnnxFactory bgi /// Gets or sets the ratio for enlarging text boxes during post-processing. 
public float UnclipRatio { get; set; } = 2.0f; - - ~Det() - { - lock (_session) - { - _session.Dispose(); - } - } - + public void Dispose() { lock (_session) diff --git a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/PaddleOcrService.cs b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/PaddleOcrService.cs index fa8639c7..c54b5797 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/PaddleOcrService.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/PaddleOcrService.cs @@ -15,7 +15,7 @@ using Size = OpenCvSharp.Size; namespace BetterGenshinImpact.Core.Recognition.OCR.Paddle; -public class PaddleOcrService : IOcrService, IDisposable +public class PaddleOcrService : IOcrService, IOcrMatchService, IDisposable { /// /// Usage: @@ -103,11 +103,11 @@ public class PaddleOcrService : IOcrService, IDisposable TestImagePath); } - public (Det, Rec) Build(BgiOnnxFactory onnxFactory) + public (Det, Rec) Build(BgiOnnxFactory onnxFactory, bool allowDuplicateChar = false, float threshold = 0.5f) { return ( new Det(DetectionModel, DetectionVersion, onnxFactory), - new Rec(RecognitionModel, RecLabel(), RecognitionVersion, onnxFactory)); + new Rec(RecognitionModel, RecLabel(), RecognitionVersion, onnxFactory, allowDuplicateChar, threshold: threshold)); } public static readonly PaddleOcrModelType V4 = Create( @@ -239,9 +239,10 @@ public class PaddleOcrService : IOcrService, IDisposable } } - public PaddleOcrService(BgiOnnxFactory bgiOnnxFactory, PaddleOcrModelType modelType) + public PaddleOcrService(BgiOnnxFactory bgiOnnxFactory, PaddleOcrModelType modelType, + bool allowDuplicateChar = false, float threshold = 0.5f) { - var (modelsDet, modelsRec) = modelType.Build(bgiOnnxFactory); + var (modelsDet, modelsRec) = modelType.Build(bgiOnnxFactory, allowDuplicateChar, threshold); _localDetModel = modelsDet; _localRecModel = modelsRec; @@ -267,13 +268,8 @@ public class PaddleOcrService : IOcrService, IDisposable /// public OcrResult OcrResult(Mat mat) { - if (mat.Channels() 
== 4) - { - using var mat3 = mat.CvtColor(ColorConversionCodes.BGRA2BGR); - return _OcrResult(mat3); - } - - return _OcrResult(mat); + using var converted = ConvertBgrIfNeeded(mat); + return _OcrResult(converted ?? mat); } /// @@ -338,6 +334,61 @@ public class PaddleOcrService : IOcrService, IDisposable Math.Clamp(rect.Bottom, 0, size.Height)); } + /// + /// 若输入为 BGRA 则转换为 BGR,否则返回 null。 + /// 调用方需在使用后 Dispose 返回的 Mat(若非 null)。 + /// + private static Mat? ConvertBgrIfNeeded(Mat mat) + { + return mat.Channels() == 4 ? mat.CvtColor(ColorConversionCodes.BGRA2BGR) : null; + } + + /// + /// 使用检测器定位文字区域后,对每个区域进行 DP 模糊匹配,返回最高置信度 (0~1)。 + /// + public double OcrMatch(Mat mat, string target) + { + var startTime = Stopwatch.GetTimestamp(); + + using var src = ConvertBgrIfNeeded(mat); + var bgr = src ?? mat; + + var rects = _localDetModel.Run(bgr); + Mat[] mats = rects.Select(rect => + { + var roi = bgr[GetCropedRect(rect.BoundingRect(), bgr.Size())]; + return roi; + }).ToArray(); + + try + { + var score = _localRecModel.RunMatch(mats, target); + var time = Stopwatch.GetElapsedTime(startTime); + Debug.WriteLine($"PaddleOcrMatch 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}"); + return score; + } + finally + { + foreach (var m in mats) m.Dispose(); + } + } + + /// + /// 不使用检测器,直接对整张图像进行 DP 模糊匹配,返回置信度 (0~1)。 + /// + public double OcrMatchDirect(Mat mat, string target) + { + var startTime = Stopwatch.GetTimestamp(); + + using var src = ConvertBgrIfNeeded(mat); + var bgr = src ?? 
mat; + + var score = _localRecModel.RunMatch([bgr], target); + var time = Stopwatch.GetElapsedTime(startTime); + Debug.WriteLine($"PaddleOcrMatchDirect 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}"); + return score; + } + public void Dispose() { _localDetModel.Dispose(); diff --git a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Rec.cs b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Rec.cs index e6641b7c..68d2a1d2 100644 --- a/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Rec.cs +++ b/BetterGenshinImpact/Core/Recognition/OCR/Paddle/Rec.cs @@ -7,22 +7,66 @@ using System.Text; using BetterGenshinImpact.Core.Recognition.OCR.Engine; using BetterGenshinImpact.Core.Recognition.OCR.Engine.data; using BetterGenshinImpact.Core.Recognition.ONNX; +using BetterGenshinImpact.Helpers; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using OpenCvSharp; namespace BetterGenshinImpact.Core.Recognition.OCR.Paddle; -public class Rec( - BgiOnnxModel model, - IReadOnlyList labels, - OcrVersionConfig config, - BgiOnnxFactory bgiOnnxFactory) - : IDisposable +/// +/// OCR 识别器,支持标准文字识别和基于动态规划的模糊匹配。 +/// 模糊匹配将目标字符串与模型原始输出序列做子序列匹配,返回 0~1 的置信度分数, +/// 比先识别再字符串匹配更能容忍 OCR 噪声。 +/// +public class Rec : IDisposable { - private readonly InferenceSession _session = bgiOnnxFactory.CreateInferenceSession(model, true); + private readonly InferenceSession _session; + private readonly IReadOnlyList _labels; + private readonly OcrVersionConfig _config; + private readonly bool _allowDuplicateChar; + private readonly float _threshold; - // _labels = File.ReadAllLines(labelFilePath); + // 模糊匹配相关字段 + + /// 标签长度集合(降序),用于从长到短贪心匹配目标字符串 + private readonly int[] _labelLengths; + + /// 标签字符串→索引字典,索引从1开始(0为CTC空白符) + private readonly IReadOnlyDictionary _labelDict; + + /// 按标签索引的权重数组,用于加权推理分数;为 null 时不加权 + private readonly float[]? 
_weights; + + /// 目标字符串→标签索引序列的 LRU 缓存,加速重复查询 + private readonly CacheHelper.LruCache _targetCache = new(128); + + /// + /// ONNX 推理输出的命名张量结构,替代匿名元组 (int[], float[])。 + /// + private readonly record struct TensorResult(int Batch, int TimeSteps, int LabelCount, float[] Data); + + public Rec( + BgiOnnxModel model, + IReadOnlyList labels, + OcrVersionConfig config, + BgiOnnxFactory bgiOnnxFactory, + bool allowDuplicateChar = false, + Dictionary? extraWeights = null, + float threshold = 0.5f) + { + _session = bgiOnnxFactory.CreateInferenceSession(model, true); + _labels = labels; + _config = config; + _allowDuplicateChar = allowDuplicateChar; + _threshold = threshold; + + _labelDict = OcrUtils.CreateLabelDict(labels, out var labelLengths); + _labelLengths = labelLengths; + _weights = extraWeights is { Count: > 0 } + ? OcrUtils.CreateWeights(extraWeights, _labelDict, labels.Count) + : null; + } public void Dispose() { @@ -33,42 +77,14 @@ public class Rec( GC.SuppressFinalize(this); } - - ~Rec() - { - lock (_session) - { - _session.Dispose(); - } - } - /// - /// Run OCR recognition on multiple images in batches. + /// 对多张图像按批次执行 OCR 识别。 /// - /// Array of images for OCR recognition. - /// Size of the batch to run OCR recognition on. - /// Array of instances corresponding to OCR recognition results of the images. public OcrRecognizerResult[] Run(Mat[] srcs, int batchSize = 0) - { - if (srcs.Length == 0) return []; - - var chooseBatchSize = batchSize != 0 ? 
batchSize : Math.Min(8, Environment.ProcessorCount); - - return srcs - .Select((x, i) => (mat: x, i)) - .OrderBy(x => x.mat.Width) - .Chunk(chooseBatchSize) - .Select(x => (result: RunMulti(x.Select(x1 => x1.mat).ToArray()), ids: x.Select(x1 => x1.i).ToArray())) - .SelectMany(x => x.result.Zip(x.ids, (result, i) => (result, i))) - .OrderBy(x => x.i) - .Select(x => x.result) - .ToArray(); - } + => RunBatch(srcs, RunMulti, batchSize); public OcrRecognizerResult Run(Mat src) - { - return RunMulti([src]).Single(); - } + => RunMulti([src]).Single(); private OcrRecognizerResult[] RunMulti(Mat[] srcs) { @@ -81,17 +97,173 @@ public class Rec( throw new ArgumentException($"src[{i}] size should not be 0, wrong input picture provided?"); } - var modelHeight = config.Shape.Height; + var resultTensors = RunInference(srcs); + + return resultTensors.SelectMany(tensor => + { + GCHandle dataHandle = default; + try + { + dataHandle = GCHandle.Alloc(tensor.Data, GCHandleType.Pinned); + var dataPtr = dataHandle.AddrOfPinnedObject(); + + return Enumerable.Range(0, tensor.Batch) + .Select(i => + { + StringBuilder sb = new(); + var lastIndex = 0; + float score = 0; + var maxIdx = new int[2]; + using var fullMat = Mat.FromPixelData(tensor.TimeSteps, tensor.LabelCount, + MatType.CV_32FC1, + dataPtr + i * tensor.TimeSteps * tensor.LabelCount * sizeof(float)); + for (var n = 0; n < tensor.TimeSteps; ++n) + { + using var row = fullMat.Row(n); + row.MinMaxIdx(out _, out var maxVal, [], maxIdx); + + if (maxIdx[1] > 0 && maxVal >= _threshold && (_allowDuplicateChar || !(n > 0 && maxIdx[1] == lastIndex))) + { + score += (float)maxVal; + sb.Append(OcrUtils.GetLabelByIndex(maxIdx[1], _labels)); + } + + lastIndex = maxIdx[1]; + } + + var text = sb.ToString(); + return new OcrRecognizerResult(text, text.Length > 0 ? 
score / text.Length : 0); + }) + .ToArray(); + } + finally + { + if (dataHandle.IsAllocated) dataHandle.Free(); + } + }).ToArray(); + } + + /// + /// 将目标字符串转换为标签索引序列,利用 LRU 缓存加速重复查询。 + /// 无法映射到标签的字符会被跳过。 + /// + public int[] GetTarget(string target) + { + if (_targetCache.TryGet(target, out var cached) && cached is not null) + return cached; + + var result = OcrUtils.MapStringToLabelIndices(target, _labelDict, _labelLengths); + _targetCache.Set(target, result); + return result; + } + + /// + /// 对一批图像执行模糊匹配,返回与目标字符串的最大平均置信度 (0~1)。 + /// + /// 待匹配图像数组 + /// 目标字符串 + /// 每批推理图像数,0表示自动 + public double RunMatch(Mat[] srcs, string target, int batchSize = 0) + { + if (srcs.Length == 0) return 0; + var targetIndexes = GetTarget(target); + if (targetIndexes.Length == 0) return 0; + + var charLevelResults = RunBatch(srcs, + mats => ProcessForMatch(RunInference(mats), targetIndexes), batchSize); + + return GetMaxScoreFlat(charLevelResults, targetIndexes); + } + + /// + /// 从 ONNX 原始输出张量中提取目标字符在每个时间步的置信度。 + /// + /// 与 RunMulti(标准 OCR)不同,此方法不做 argmax(MinMaxIdx), + /// 而是按目标字符的 label 索引直接查找对应位置的原始置信度。 + /// 这样即使目标字符不是某个时间步的最高置信度候选,DP 仍然能拿到其实际分数进行匹配。 + /// + /// + /// RunInference 返回的 (shape, data) 张量数组 + /// 目标字符串映射后的 label 索引序列 + /// 每张图像对应一个 (labelIndex, confidence) 数组,供 DP 匹配使用 + private (int, float)[][] ProcessForMatch(TensorResult[] resultTensors, int[] targetIndexes) + { + // 目标字符去重(排除 CTC 空白符 index=0) + var targetSet = new HashSet(targetIndexes); + targetSet.Remove(0); + + return resultTensors.Select(tensor => + { + var chars = new List<(int, float)>(); + for (var n = 0; n < tensor.TimeSteps * tensor.Batch; n++) + { + // 直接按索引查找目标字符的置信度,而非对整行取 argmax + var rowOffset = n * tensor.LabelCount; + foreach (var labelIdx in targetSet) + { + if (labelIdx >= tensor.LabelCount) continue; + var raw = tensor.Data[rowOffset + labelIdx]; + var confidence = _weights is not null + ? 
raw * _weights[labelIdx] + : raw; + if (confidence > _threshold) + chars.Add((labelIdx, confidence)); + } + } + return chars.ToArray(); + }).ToArray(); + } + + /// + /// 将多张图像的字符级别结果展平后,计算与 target 的最大匹配分数。 + /// 分母使用 target.Length,得到的是每个目标字符的平均置信度 (0~1)。 + /// + private static double GetMaxScoreFlat((int, float)[][] result, int[] target) + { + var flatResult = result.SelectMany(x => x).ToArray(); + return OcrUtils.GetMaxScoreDp(flatResult, target, target.Length); + } + + /// + /// 通用批处理:按宽度排序、分批推理、恢复原始顺序 + /// + private T[] RunBatch(Mat[] srcs, Func process, int batchSize = 0) + { + if (srcs.Length == 0) return []; + + var chooseBatchSize = batchSize != 0 ? batchSize : Math.Min(8, Environment.ProcessorCount); + + return srcs + .Select((x, i) => (mat: x, i)) + .OrderBy(x => x.mat.Width) + .Chunk(chooseBatchSize) + .Select(chunk => + { + var mats = chunk.Select(x => x.mat).ToArray(); + var result = process(mats); + return (result, ids: chunk.Select(x => x.i).ToArray()); + }) + .SelectMany(x => x.result.Zip(x.ids, (r, i) => (r, i))) + .OrderBy(x => x.i) + .Select(x => x.r) + .ToArray(); + } + + /// + /// 执行 ONNX 推理,返回每张图像的原始 (shape, data) 张量 + /// + private TensorResult[] RunInference(Mat[] srcs) + { + var modelHeight = _config.Shape.Height; var maxWidth = (int)Math.Ceiling(srcs.Max(src => { var size = src.Size(); return 1.0 * size.Width / size.Height * modelHeight; })); List> owners = []; - (int[], float[])[] resultTensors; try { - resultTensors = srcs + return srcs // .AsParallel() .Select(src => { @@ -111,12 +283,10 @@ public class Rec( { owners.Add(owner); } - return result; } finally { - // Only dispose Mats created in this scope if (channel3 != null && !ReferenceEquals(channel3, src)) { channel3.Dispose(); @@ -124,75 +294,31 @@ public class Rec( } }) .Select(inputTensor => + { + lock (_session) { - lock (_session) - { - // 多线程推理会出现问题,加锁解决。 - using IDisposableReadOnlyCollection results = _session.Run([ - NamedOnnxValue.CreateFromTensor(_session.InputNames[0], 
inputTensor) - ]); - var output = results[0]; - if (output.ElementType is not TensorElementType.Float) - throw new Exception($"Unexpected output tensor type: {output.ElementType}"); + // 多线程推理会出现问题,加锁解决。 + using IDisposableReadOnlyCollection results = _session.Run([ + NamedOnnxValue.CreateFromTensor(_session.InputNames[0], inputTensor) + ]); + var output = results[0]; + if (output.ElementType is not TensorElementType.Float) + throw new Exception($"Unexpected output tensor type: {output.ElementType}"); - if (output.ValueType is not OnnxValueType.ONNX_TYPE_TENSOR) - throw new Exception($"Unexpected output tensor value type: {output.ValueType}"); - var tensor = output.AsTensor(); - // 因为一个已知bug,tensor中内存在dml下使用完后会被释放掉,锁之外的代码会报错 - return (tensor.Dimensions.ToArray(), tensor.ToArray()); - } + if (output.ValueType is not OnnxValueType.ONNX_TYPE_TENSOR) + throw new Exception($"Unexpected output tensor value type: {output.ValueType}"); + var tensor = output.AsTensor(); + // 因为一个已知bug,tensor中内存在dml下使用完后会被释放掉,锁之外的代码会报错 + var dims = tensor.Dimensions; + return new TensorResult(dims[0], dims[1], dims[2], tensor.ToArray()); } - ).ToArray(); + }).ToArray(); } finally { owners.ForEach(x => { x.Dispose(); }); } - - return resultTensors.SelectMany(resultTensor => - { - var resultArray = resultTensor.Item2; - var resultShape = resultTensor.Item1; - GCHandle dataHandle = default; - try - { - dataHandle = GCHandle.Alloc(resultArray, GCHandleType.Pinned); - var dataPtr = dataHandle.AddrOfPinnedObject(); - var labelCount = resultShape[2]; - var charCount = resultShape[1]; - - return Enumerable.Range(0, resultShape[0]) - .Select(i => - { - StringBuilder sb = new(); - var lastIndex = 0; - float score = 0; - for (var n = 0; n < charCount; ++n) - { - using var mat = Mat.FromPixelData(1, labelCount, MatType.CV_32FC1, - dataPtr + (n + i * charCount) * labelCount * sizeof(float)); - var maxIdx = new int[2]; - mat.MinMaxIdx(out _, out var maxVal, [], maxIdx); - - if (maxIdx[1] > 0 && !(n > 0 && 
maxIdx[1] == lastIndex)) - { - score += (float)maxVal; - sb.Append(OcrUtils.GetLabelByIndex(maxIdx[1], labels)); - } - - lastIndex = maxIdx[1]; - } - - return new OcrRecognizerResult(sb.ToString(), score / sb.Length); - }) - .ToArray(); - } - finally - { - dataHandle.Free(); - } - }).ToArray(); } - public string GetConfigName => config.Name; -} \ No newline at end of file + public string GetConfigName => _config.Name; +} diff --git a/BetterGenshinImpact/GameTask/Common/Job/SwitchPartyTask.cs b/BetterGenshinImpact/GameTask/Common/Job/SwitchPartyTask.cs index 8d1d9cd9..e7b0d730 100644 --- a/BetterGenshinImpact/GameTask/Common/Job/SwitchPartyTask.cs +++ b/BetterGenshinImpact/GameTask/Common/Job/SwitchPartyTask.cs @@ -1,4 +1,5 @@ using BetterGenshinImpact.Core.Recognition; +using BetterGenshinImpact.Core.Recognition.OCR; using BetterGenshinImpact.Core.Simulator; using BetterGenshinImpact.Core.Simulator.Extensions; using BetterGenshinImpact.GameTask.Common.BgiVision; @@ -9,6 +10,7 @@ using BetterGenshinImpact.View.Drawable; using Microsoft.Extensions.Logging; using OpenCvSharp; using System; +using System.Collections.Generic; using System.Linq; using System.Text.RegularExpressions; using System.Threading; @@ -28,56 +30,17 @@ public class SwitchPartyTask public async Task Start(string partyName, CancellationToken ct) { - bool isInPartyViewUi = false; + var useOcrMatch = TaskContext.Instance().Config.OtherConfig.OcrConfig.UseOcrMatchForPartySwitch; Logger.LogInformation("尝试切换至队伍: {Name}", partyName); using var ra1 = CaptureToRectArea(); + // 确保进入队伍配置界面 + bool isInPartyViewUi = false; if (!Bv.IsInPartyViewUi(ra1)) { isInPartyViewUi = true; - // 如果不在主界面,则返回主界面 - if (!Bv.IsInMainUi(ra1)) - { - await _returnMainUiTask.Start(ct); - await Delay(200, ct); - using var raAfterMain = CaptureToRectArea(); - if (!Bv.IsInMainUi(raAfterMain)) - { - throw new InvalidOperationException("未能返回主界面"); - } - } - - // 尝试打开队伍配置页面 - const int maxAttempts = 2; - bool isOpened = false; - for (int 
attempt = 1; attempt <= maxAttempts; attempt++) - { - Simulation.SendInput.SimulateAction(GIActions.OpenPartySetupScreen); - - // 考虑加载时间 2s,共检查 4.2s,如果失败则抛出异常 - - for (int i = 0; i < 7; i++) // 检查 7 次 - { - await Delay(600, ct); - using var raCheck = CaptureToRectArea(); - if (Bv.IsInPartyViewUi(raCheck)) - { - isOpened = true; - break; - } - } - - if (isOpened) - { - break; // 页面已打开,跳出循环 - } - } - - if (!isOpened) - { - throw new PartySetupFailedException("未能打开队伍配置界面"); - } + await EnsurePartyViewOpen(ra1, ct); } await Delay(500, ct); @@ -85,33 +48,15 @@ public class SwitchPartyTask using var ra = CaptureToRectArea(); var partyViewBtn = ra.Find(ElementAssets.Instance.PartyBtnChooseView); - // OCR 当前队伍名称(无法单字,中间禁止空格) - var currTeamName = ra.Find(new RecognitionObject + if (!partyViewBtn.IsExist()) { - RecognitionType = RecognitionTypes.Ocr, - RegionOfInterest = new Rect(partyViewBtn.Right, partyViewBtn.Top, (int)(350 * _assetScale), - partyViewBtn.Height) - }).Text; - - var tempName = currTeamName - .Replace("\"", "") // 移除所有双引号(核心新增,解决日志里的""问题) - .Replace("\r\n", "") // 清理Windows换行符 - .Replace("\r", ""); // 先清理所有双引号,避免引号干扰后续处理 - - // 核心逻辑:找到第一个换行符(\n)的位置,截断并删除换行+后面所有字符 - int firstNewLineIndex = tempName.IndexOf('\n'); - if (firstNewLineIndex != -1) // 存在换行符,截取到换行符前 - { - tempName = tempName.Substring(0, firstNewLineIndex); + Logger.LogWarning("未找到队伍选择按钮,无法判断当前队伍"); + throw new PartySetupFailedException("未找到队伍选择按钮"); } - - // 最后统一去首尾所有空白(空格、制表符、回车符\r等),得到纯净队伍名 - currTeamName = tempName.Trim(); - Logger.LogInformation("切换队伍,当前队伍名称: {Text},使用正则表达式规则进行模糊匹配", currTeamName); - if (Regex.IsMatch(currTeamName, partyName)) + // 检查当前队伍是否已是目标 + if (IsCurrentTeamMatch(ra, partyViewBtn, partyName, useOcrMatch)) { - Logger.LogInformation("当前队伍[{Name}]即为目标队伍,无需切换", currTeamName); if (isInPartyViewUi) { Simulation.SendInput.Keyboard.KeyPress(User32.VK.VK_ESCAPE); @@ -122,101 +67,76 @@ public class SwitchPartyTask return true; } - var menu = await NewRetry.WaitForElementAppear( - 
ElementAssets.Instance.PartyBtnDelete, - () => partyViewBtn.Click(),// 点击队伍选择按钮 - ct, - 4, - 500 - ); - if (!menu) - { - throw new PartySetupFailedException("未能打开队伍选择页面"); - } + // 打开队伍选择页面 + var partyDeleteBtn = await OpenPartyChoosePage(partyViewBtn, ct); + await ScrollToTop(ct); - ImageRegion? switchRa = null; - Region? partyDeleteBtn = null; - using (var ocrRa = CaptureToRectArea()) - { - var openPartyChooseSuccess = await NewRetry.WaitForAction(() => - { - switchRa = ocrRa; - partyDeleteBtn = switchRa.Find(ElementAssets.Instance.PartyBtnDelete); - return partyDeleteBtn.IsExist(); - }, ct, 5); - - if (!openPartyChooseSuccess || switchRa == null || partyDeleteBtn == null) - { - throw new PartySetupFailedException("未能打开队伍配置界面"); - } - } - - // 点击到最上方 - await Task.Delay(50, ct); - GameCaptureRegion.GameRegion1080PPosClick(700, 125); - await Task.Delay(50, ct); - Simulation.SendInput.Mouse.LeftButtonDown(); - await Task.Delay(450, ct); - Simulation.SendInput.Mouse.LeftButtonUp(); - await Task.Delay(100, ct); - - Rect regionOfInterest = new Rect(0, (int)(80 * _assetScale), partyDeleteBtn.Right, partyDeleteBtn.Top - (int)(80 * _assetScale)); - RecognitionObject recognitionObject = new RecognitionObject + // 逐页查找目标队伍 + Rect regionOfInterest = new(0, (int)(80 * _assetScale), partyDeleteBtn.Right, partyDeleteBtn.Top - (int)(80 * _assetScale)); + var recognitionObject = new RecognitionObject { RecognitionType = RecognitionTypes.Ocr, RegionOfInterest = regionOfInterest, DrawOnWindow = true, Name = "队伍名称", - DrawOnWindowPen= System.Drawing.Pens.White + DrawOnWindowPen = System.Drawing.Pens.White }; - // 逐页查找 + try { - for (var i = 0; i < 16; i++) // 6.0版本最多20个队伍 + for (var i = 0; i < 16; i++) // 6.0版本最多20个队伍 { using var page = CaptureToRectArea(); + var nameList = page.FindMulti(recognitionObject); - var partySwitchNameRaList = page.FindMulti(recognitionObject); - - if (partySwitchNameRaList == null || partySwitchNameRaList.Count <= 0) + if (nameList == null || 
nameList.Count <= 0) { Logger.LogInformation("管理队伍界面文字识别失败"); break; } - // 当前页存在则直接点击 - foreach (var textRegion in partySwitchNameRaList) + // 在当前页查找匹配 + var (match, score) = FindMatchInPage(page, nameList, partyName, useOcrMatch); + if (match != null) { - if (Regex.IsMatch(textRegion.Text, partyName)) - { - page.ClickTo(textRegion.Right + textRegion.Width, textRegion.Bottom); - await Delay(200, ct); - Logger.LogInformation("切换队伍成功: {Text}", textRegion.Text); - await ConfirmParty(page, ct, isInPartyViewUi); - - RunnerContext.Instance.ClearCombatScenes(); - return true; - } + page.ClickTo(match.Right + match.Width, match.Bottom); + await Delay(200, ct); + if (useOcrMatch) + Logger.LogInformation("切换队伍成功: {Text}(匹配分数: {Score:F4})", match.Text, score); + else + Logger.LogInformation("切换队伍成功: {Text}", match.Text); + await ConfirmParty(page, ct, isInPartyViewUi); + RunnerContext.Instance.ClearCombatScenes(); + return true; } - Region lowest = partySwitchNameRaList.Where(r => r.X > 35 * _assetScale && r.X < 100 * _assetScale).OrderBy(r => r.Y).Last(); + // 判断是否已遍历所有队伍 + var lowest = nameList + .Where(r => r.X > 35 * _assetScale && r.X < 100 * _assetScale) + .OrderBy(r => r.Y) + .LastOrDefault(); + if (lowest == null) + { + Logger.LogInformation("未找到符合坐标范围的队伍名称,跳过翻页判断"); + continue; + } lowest.DrawSelf("底部的队伍"); - if (lowest.Y < 777 * _assetScale) // 如果最底下是空队伍则不会有队伍名,以此判断是否已遍历完成 + if (lowest.Y < 777 * _assetScale) // 如果最底下是空队伍则不会有队伍名,以此判断是否已遍历完成 { Logger.LogInformation("已抵达最后一个队伍"); break; } - // 点击下一页 + // 翻页 if (i == 0) { - // #ebe4d8 首次点一下第一个,防止第五个被点击过 + // 首次点一下第一个,防止第五个被点击过 page.ClickTo(600 * _assetScale, 200 * _assetScale); - await Task.Delay(300, ct); // 等待动画 + await Task.Delay(300, ct); } - page.ClickTo(regionOfInterest.X + regionOfInterest.Width / 2, lowest.Bottom); // 点击最下方队伍下移 + page.ClickTo(regionOfInterest.X + regionOfInterest.Width / 2, lowest.Bottom); await Delay(400, ct); } } @@ -227,14 +147,181 @@ public class SwitchPartyTask // 未找到 
Logger.LogError("未找到队伍: {Name},返回主界面", partyName); - Logger.LogInformation("如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试正则表达式"); + Logger.LogInformation(useOcrMatch + ? "如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试调整 OcrMatch 模糊匹配阈值" + : "如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试正则表达式"); await _returnMainUiTask.Start(ct); return false; } + /// + /// 确保队伍配置界面已打开。如果不在主界面则先返回主界面,然后打开队伍配置。 + /// + private async Task EnsurePartyViewOpen(ImageRegion currentScreen, CancellationToken ct) + { + if (!Bv.IsInMainUi(currentScreen)) + { + await _returnMainUiTask.Start(ct); + await Delay(200, ct); + using var raMain = CaptureToRectArea(); + if (!Bv.IsInMainUi(raMain)) + throw new InvalidOperationException("未能返回主界面"); + } + + const int maxAttempts = 2; + for (int attempt = 1; attempt <= maxAttempts; attempt++) + { + Simulation.SendInput.SimulateAction(GIActions.OpenPartySetupScreen); + for (int i = 0; i < 7; i++) // 考虑加载时间 2s,共检查 4.2s + { + await Delay(600, ct); + using var raCheck = CaptureToRectArea(); + if (Bv.IsInPartyViewUi(raCheck)) return; + } + } + + throw new PartySetupFailedException("未能打开队伍配置界面"); + } + + /// + /// 检查当前队伍名称是否匹配目标 + /// + private bool IsCurrentTeamMatch(ImageRegion ra, Region partyViewBtn, string partyName, bool useOcrMatch) + { + var roi = new Rect(partyViewBtn.Right, partyViewBtn.Top, (int)(350 * _assetScale), partyViewBtn.Height); + + if (useOcrMatch) + { + var matchService = OcrFactory.PaddleMatch; + var threshold = TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold; + using var region = ra.DeriveCrop(roi); + var score = matchService.OcrMatch(region.SrcMat, partyName); + Logger.LogInformation("切换队伍,当前队伍 OcrMatch 分数: {Score:F4},判断阈值: {Threshold}", score, threshold); + if (score >= threshold) + { + Logger.LogInformation("当前队伍即为目标队伍(匹配分数: {Score:F4}),无需切换", score); + return true; + } + + return false; + } + + var text = CleanOcrText(ra.Find(new RecognitionObject + { + RecognitionType = RecognitionTypes.Ocr, + RegionOfInterest = roi + }).Text); + 
Logger.LogInformation("切换队伍,当前队伍名称: {Text},使用正则表达式规则进行模糊匹配", text); + if (Regex.IsMatch(text, partyName)) + { + Logger.LogInformation("当前队伍[{Name}]即为目标队伍,无需切换", text); + return true; + } + + return false; + } + + /// + /// 在当前页的文字区域列表中查找匹配目标的队伍 + /// + private (Region? match, double score) FindMatchInPage( + ImageRegion page, List textRegions, string partyName, bool useOcrMatch) + { + if (useOcrMatch) + { + var matchService = OcrFactory.PaddleMatch; + var threshold = TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold; + Region? bestMatch = null; + double bestScore = 0; + var imgW = page.SrcMat.Width; + var imgH = page.SrcMat.Height; + foreach (var region in textRegions) + { + var cx = Math.Max(0, region.X); + var cy = Math.Max(0, region.Y); + var cw = Math.Min(region.Width, imgW - cx); + var ch = Math.Min(region.Height, imgH - cy); + if (cw <= 0 || ch <= 0) + continue; + + using var cropped = page.DeriveCrop(cx, cy, cw, ch); + var score = matchService.OcrMatchDirect(cropped.SrcMat, partyName); + if (score >= threshold && score > bestScore) + { + bestScore = score; + bestMatch = region; + } + } + + return (bestMatch, bestScore); + } + + foreach (var region in textRegions) + { + if (Regex.IsMatch(region.Text, partyName)) + return (region, 0); + } + + return (null, 0); + } + + /// + /// 打开队伍选择页面(点击选择按钮并等待加载) + /// + private static async Task OpenPartyChoosePage(Region partyViewBtn, CancellationToken ct) + { + var menu = await NewRetry.WaitForElementAppear( + ElementAssets.Instance.PartyBtnDelete, + () => partyViewBtn.Click(), + ct, 4, 500); + if (!menu) + throw new PartySetupFailedException("未能打开队伍选择页面"); + + Region? 
partyDeleteBtn = null; + var success = await NewRetry.WaitForAction(() => + { + using var ocrRa = CaptureToRectArea(); + partyDeleteBtn = ocrRa.Find(ElementAssets.Instance.PartyBtnDelete); + return partyDeleteBtn.IsExist(); + }, ct, 5); + + if (!success || partyDeleteBtn == null) + throw new PartySetupFailedException("未能打开队伍配置界面"); + + return partyDeleteBtn; + } + + /// + /// 滚动列表到最上方 + /// + private static async Task ScrollToTop(CancellationToken ct) + { + await Task.Delay(50, ct); + GameCaptureRegion.GameRegion1080PPosClick(700, 125); + await Task.Delay(50, ct); + Simulation.SendInput.Mouse.LeftButtonDown(); + await Task.Delay(450, ct); + Simulation.SendInput.Mouse.LeftButtonUp(); + await Task.Delay(100, ct); + } + + /// + /// 清理 OCR 识别结果中的干扰字符 + /// + private static string CleanOcrText(string? text) + { + if (string.IsNullOrEmpty(text)) + return string.Empty; + var cleaned = text.Replace("\"", "").Replace("\r\n", "").Replace("\r", ""); + var newLineIndex = cleaned.IndexOf('\n'); + if (newLineIndex != -1) + cleaned = cleaned[..newLineIndex]; + return cleaned.Trim(); + } + private async Task ConfirmParty(ImageRegion page, CancellationToken ct, bool isInPartyViewUi = false) { - var r1 = Bv.ClickWhiteConfirmButton(page.DeriveCrop(0, page.Height / 4, page.Width / 4, page.Height - page.Height / 4)); + Bv.ClickWhiteConfirmButton(page.DeriveCrop(0, page.Height / 4, page.Width / 4, page.Height - page.Height / 4)); var partyChooseUiClosed = await NewRetry.WaitForAction(() => { using var ra2 = CaptureToRectArea(); @@ -244,9 +331,10 @@ public class SwitchPartyTask { throw new PartySetupFailedException("选择队伍失败,等待队伍切换超时!"); } + await Delay(200, ct); using var ra = CaptureToRectArea(); - var r2 = Bv.ClickWhiteConfirmButton(ra.DeriveCrop(page.Width - page.Width / 4, page.Height / 4, page.Width / 4, page.Height - page.Height / 4)); + Bv.ClickWhiteConfirmButton(ra.DeriveCrop(page.Width - page.Width / 4, page.Height / 4, page.Width / 4, page.Height - page.Height / 4)); await 
Delay(500, ct); if (isInPartyViewUi) await _returnMainUiTask.Start(ct); } diff --git a/BetterGenshinImpact/Helpers/CacheHelper.cs b/BetterGenshinImpact/Helpers/CacheHelper.cs new file mode 100644 index 00000000..10a06804 --- /dev/null +++ b/BetterGenshinImpact/Helpers/CacheHelper.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; + +namespace BetterGenshinImpact.Helpers; + +public abstract class CacheHelper +{ + public class LruCache where TKey : notnull where TValue : class + { + private readonly int _capacity; + private readonly TimeSpan? _expireAfter; + private readonly Dictionary> _cacheMap; + private readonly LinkedList<(TKey Key, TValue Value, DateTime ExpireAt)> _lruList; + private readonly object _lock = new(); + + public LruCache(int capacity, TimeSpan? expireAfter = null) + { + ArgumentOutOfRangeException.ThrowIfNegativeOrZero(capacity); + _capacity = capacity; + _expireAfter = expireAfter; + _cacheMap = new Dictionary>(); + _lruList = []; + } + + public bool TryGet(TKey key, out TValue? value) + { + lock (_lock) + { + if (_cacheMap.TryGetValue(key, out var node)) + { + if (_expireAfter.HasValue && DateTime.UtcNow > node.Value.ExpireAt) + { + _lruList.Remove(node); + _cacheMap.Remove(key); + value = null; + return false; + } + _lruList.Remove(node); + _lruList.AddFirst(node); + value = node.Value.Value; + return true; + } + value = null; + return false; + } + } + + public void Set(TKey key, TValue value) + { + lock (_lock) + { + var expireAt = _expireAfter.HasValue ? 
DateTime.UtcNow.Add(_expireAfter.Value) : default; + + if (_cacheMap.TryGetValue(key, out var node)) + { + node.Value = (key, value, expireAt); + _lruList.Remove(node); + _lruList.AddFirst(node); + } + else + { + if (_cacheMap.Count >= _capacity) + { + var lru = _lruList.Last; + if (lru != null) + { + _cacheMap.Remove(lru.Value.Key); + _lruList.RemoveLast(); + } + } + var newNode = new LinkedListNode<(TKey, TValue, DateTime)>((key, value, expireAt)); + _lruList.AddFirst(newNode); + _cacheMap[key] = newNode; + } + } + } + + public bool Remove(TKey key) + { + lock (_lock) + { + if (!_cacheMap.TryGetValue(key, out var node)) return false; + _lruList.Remove(node); + _cacheMap.Remove(key); + return true; + } + } + + public int Count + { + get { lock (_lock) { return _cacheMap.Count; } } + } + + public void Clear() + { + lock (_lock) + { + _cacheMap.Clear(); + _lruList.Clear(); + } + } + } +} diff --git a/BetterGenshinImpact/View/Pages/CommonSettingsPage.xaml b/BetterGenshinImpact/View/Pages/CommonSettingsPage.xaml index 247162a6..1e511bec 100644 --- a/BetterGenshinImpact/View/Pages/CommonSettingsPage.xaml +++ b/BetterGenshinImpact/View/Pages/CommonSettingsPage.xaml @@ -21,8 +21,8 @@ - - + + @@ -434,19 +434,19 @@ Width="200" Margin="0,0,12,0" VerticalAlignment="Center" - Minimum="0" - Maximum="1" - TickFrequency="0.1" IsSnapToTickEnabled="True" + Maximum="1" + Minimum="0" + TickFrequency="0.1" Value="{Binding Config.MaskWindowConfig.TextOpacity, Mode=TwoWay}" /> @@ -1004,7 +1004,7 @@ - + - + @@ -1179,10 +1178,10 @@ Grid.Column="1" MinWidth="100" Margin="0,0,36,0" + DisplayMemberPath="Item2" ItemsSource="{Binding ServerTimeZones}" SelectedValue="{Binding Config.OtherConfig.ServerTimeZoneOffset}" - SelectedValuePath="Item1" - DisplayMemberPath="Item2" /> + SelectedValuePath="Item1" /> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git 
a/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrMatchFallbackServiceTests.cs b/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrMatchFallbackServiceTests.cs new file mode 100644 index 00000000..3e5dd144 --- /dev/null +++ b/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrMatchFallbackServiceTests.cs @@ -0,0 +1,195 @@ +using BetterGenshinImpact.Core.Recognition.OCR; +using OpenCvSharp; + +namespace BetterGenshinImpact.UnitTest.CoreTests.RecognitionTests.OCRTests; + +public class OcrMatchFallbackServiceTests +{ + #region LevenshteinDistance + + [Fact] + public void LevenshteinDistance_IdenticalStrings_ReturnsZero() + { + Assert.Equal(0, OcrMatchFallbackService.LevenshteinDistance("abc", "abc")); + } + + [Fact] + public void LevenshteinDistance_EmptyAndNonEmpty_ReturnsLength() + { + Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("", "abc")); + Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("abc", "")); + } + + [Fact] + public void LevenshteinDistance_BothEmpty_ReturnsZero() + { + Assert.Equal(0, OcrMatchFallbackService.LevenshteinDistance("", "")); + } + + [Fact] + public void LevenshteinDistance_SingleSubstitution() + { + // "确认" vs "确忍" — 一个字符替换 + Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("确认", "确忍")); + } + + [Fact] + public void LevenshteinDistance_Insertion() + { + Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("ac", "abc")); + } + + [Fact] + public void LevenshteinDistance_Deletion() + { + Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("abc", "ac")); + } + + [Fact] + public void LevenshteinDistance_CompletelyDifferent() + { + Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("abc", "xyz")); + } + + #endregion + + #region ComputeTextSimilarity + + [Fact] + public void ComputeTextSimilarity_ExactMatch_ReturnsOne() + { + Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("确认", "确认")); + } + + [Fact] 
+ public void ComputeTextSimilarity_TextContainsTarget_ReturnsOne() + { + // "确认购买" 包含 "确认" + Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("确认购买", "确认")); + } + + [Fact] + public void ComputeTextSimilarity_TargetContainsText_ReturnsRatio() + { + // "确认" 被 "确认购买" 包含,长度比 = 2/4 + Assert.Equal(0.5, OcrMatchFallbackService.ComputeTextSimilarity("确认", "确认购买")); + } + + [Fact] + public void ComputeTextSimilarity_EmptyTarget_ReturnsOne() + { + Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("任意文字", "")); + } + + [Fact] + public void ComputeTextSimilarity_EmptyText_ReturnsZero() + { + Assert.Equal(0.0, OcrMatchFallbackService.ComputeTextSimilarity("", "确认")); + } + + [Fact] + public void ComputeTextSimilarity_SingleCharDifference() + { + // "确忍" vs "确认" — 距离1, 最大长度2, 相似度 = 1 - 1/2 = 0.5 + Assert.Equal(0.5, OcrMatchFallbackService.ComputeTextSimilarity("确忍", "确认")); + } + + [Fact] + public void ComputeTextSimilarity_CompletelyDifferent_ReturnsZero() + { + // 完全不同的字符串 + Assert.Equal(0.0, OcrMatchFallbackService.ComputeTextSimilarity("甲乙", "丙丁")); + } + + #endregion + + #region OcrMatch / OcrMatchDirect 集成测试(使用 FakeOcrService) + + [Fact] + public void OcrMatch_WhenRegionContainsTarget_ReturnsOne() + { + var fakeOcr = new FakeOcrService(new OcrResult([ + new OcrResultRegion(default, "确认购买", 0.9f) + ])); + var sut = new OcrMatchFallbackService(fakeOcr); + + using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White); + var score = sut.OcrMatch(mat, "确认"); + + Assert.Equal(1.0, score); + } + + [Fact] + public void OcrMatch_MultipleRegions_ReturnsBestScore() + { + var fakeOcr = new FakeOcrService(new OcrResult([ + new OcrResultRegion(default, "其他文字", 0.9f), + new OcrResultRegion(default, "确认", 0.9f) + ])); + var sut = new OcrMatchFallbackService(fakeOcr); + + using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White); + var score = sut.OcrMatch(mat, "确认"); + + Assert.Equal(1.0, score); + } + + [Fact] + public void 
OcrMatch_NoRegions_ReturnsZero() + { + var fakeOcr = new FakeOcrService(new OcrResult([])); + var sut = new OcrMatchFallbackService(fakeOcr); + + using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White); + var score = sut.OcrMatch(mat, "确认"); + + Assert.Equal(0.0, score); + } + + [Fact] + public void OcrMatchDirect_ExactMatch_ReturnsOne() + { + var fakeOcr = new FakeOcrService(ocrWithoutDetectorResult: "确认"); + var sut = new OcrMatchFallbackService(fakeOcr); + + using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White); + var score = sut.OcrMatchDirect(mat, "确认"); + + Assert.Equal(1.0, score); + } + + [Fact] + public void OcrMatchDirect_PartialMatch_ReturnsPartialScore() + { + var fakeOcr = new FakeOcrService(ocrWithoutDetectorResult: "确忍"); + var sut = new OcrMatchFallbackService(fakeOcr); + + using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White); + var score = sut.OcrMatchDirect(mat, "确认"); + + Assert.Equal(0.5, score, 0.01); + } + + #endregion + + /// + /// 用于测试 OcrMatchFallbackService 的假 IOcrService。 + /// + private class FakeOcrService : IOcrService + { + private readonly OcrResult? _ocrResult; + private readonly string _ocrWithoutDetectorResult; + + public FakeOcrService(OcrResult? ocrResult = null, string ocrWithoutDetectorResult = "") + { + _ocrResult = ocrResult; + _ocrWithoutDetectorResult = ocrWithoutDetectorResult; + } + + public string Ocr(Mat mat) => _ocrResult?.Text ?? ""; + + public string OcrWithoutDetector(Mat mat) => _ocrWithoutDetectorResult; + + public OcrResult OcrResult(Mat mat) => _ocrResult ?? 
new OcrResult([]); + } +} diff --git a/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrUtilsTests.cs b/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrUtilsTests.cs new file mode 100644 index 00000000..2e8ae821 --- /dev/null +++ b/Test/BetterGenshinImpact.UnitTest/CoreTests/RecognitionTests/OCRTests/OcrUtilsTests.cs @@ -0,0 +1,291 @@ +using BetterGenshinImpact.Core.Recognition.OCR.Engine; + +namespace BetterGenshinImpact.UnitTest.CoreTests.RecognitionTests.OCRTests; + +public class OcrUtilsTests +{ + #region CreateLabelDict + + [Fact] + public void CreateLabelDict_SingleCharLabels_MapsCorrectly() + { + // 标签 ["a","b","c"] → a=1, b=2, c=3, " "=4 + IReadOnlyList labels = ["a", "b", "c"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + Assert.Equal(1, dict["a"]); + Assert.Equal(2, dict["b"]); + Assert.Equal(3, dict["c"]); + Assert.Equal(4, dict[" "]); + // 所有标签都是长度1,labelLengths = [1] + Assert.Single(lengths); + Assert.Equal(1, lengths[0]); + } + + [Fact] + public void CreateLabelDict_NoZeroLength() + { + // 不应包含长度为0的项(防止无限循环) + IReadOnlyList labels = ["x", "y"]; + OcrUtils.CreateLabelDict(labels, out var lengths); + Assert.DoesNotContain(0, lengths); + } + + [Fact] + public void CreateLabelDict_LengthsDescendingOrder() + { + // 多字节标签时,labelLengths 应降序排列(先试长匹配) + IReadOnlyList labels = ["a", "ab", "b"]; + OcrUtils.CreateLabelDict(labels, out var lengths); + for (var i = 0; i < lengths.Length - 1; i++) + { + Assert.True(lengths[i] >= lengths[i + 1], "labelLengths 应为降序"); + } + } + + #endregion + + #region MapStringToLabelIndices + + [Fact] + public void MapStringToLabelIndices_SimpleMatch() + { + // labels: ["a","b","c"] → a=1, b=2, c=3 + IReadOnlyList labels = ["a", "b", "c"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("abc", dict, lengths); + + Assert.Equal([1, 2, 3], result); + } + + [Fact] + public void 
MapStringToLabelIndices_SkipsUnknownChars() + { + // "aXb" 中 X 不在标签里,应被跳过 + IReadOnlyList labels = ["a", "b"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("aXb", dict, lengths); + + Assert.Equal([1, 2], result); + } + + [Fact] + public void MapStringToLabelIndices_PrefersLongerMatch() + { + // 标签含 "ab" 和 "a",输入 "ab" 应优先匹配长标签 "ab" + IReadOnlyList labels = ["a", "ab", "b"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("ab", dict, lengths); + + // "ab" 整体匹配为 index 2(labels 中第2个元素) + Assert.Single(result); + Assert.Equal(2, result[0]); + } + + [Fact] + public void MapStringToLabelIndices_EmptyString_ReturnsEmpty() + { + IReadOnlyList labels = ["a", "b"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("", dict, lengths); + + Assert.Empty(result); + } + + [Fact] + public void MapStringToLabelIndices_AllUnknown_ReturnsEmpty() + { + IReadOnlyList labels = ["a", "b"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("XYZ", dict, lengths); + + Assert.Empty(result); + } + + [Fact] + public void MapStringToLabelIndices_SpaceChar_MapsToSpaceIndex() + { + // 空格字符映射到 labels.Count + 1 + IReadOnlyList labels = ["a", "b"]; + var dict = OcrUtils.CreateLabelDict(labels, out var lengths); + + var result = OcrUtils.MapStringToLabelIndices("a b", dict, lengths); + + // a=1, " "=3, b=2 + Assert.Equal([1, 3, 2], result); + } + + #endregion + + #region GetMaxScoreDP + + [Fact] + public void GetMaxScoreDP_PerfectMatch_ReturnsFullScore() + { + // result 中按顺序包含 target 的所有元素,置信度均为 1.0 + (int, float)[] result = [(1, 1.0f), (2, 1.0f), (3, 1.0f)]; + int[] target = [1, 2, 3]; + + var score = OcrUtils.GetMaxScoreDp(result, target, target.Length); + + Assert.Equal(1.0, score); + } + + [Fact] + public void 
GetMaxScoreDP_NoMatch_ReturnsZero() + { + // result 中不包含 target 的任何元素 + (int, float)[] result = [(4, 1.0f), (5, 1.0f)]; + int[] target = [1, 2]; + + var score = OcrUtils.GetMaxScoreDp(result, target, target.Length); + + Assert.Equal(0, score); + } + + [Fact] + public void GetMaxScoreDP_EmptyTarget_ReturnsZero() + { + (int, float)[] result = [(1, 1.0f)]; + int[] target = []; + + var score = OcrUtils.GetMaxScoreDp(result, target, 1); + + Assert.Equal(0, score); + } + + [Fact] + public void GetMaxScoreDP_PartialMatch_ReturnsZero() + { + // target 需要 [1,2,3],但 result 只有 [1,2],无法完整匹配 + (int, float)[] result = [(1, 1.0f), (2, 1.0f)]; + int[] target = [1, 2, 3]; + + var score = OcrUtils.GetMaxScoreDp(result, target, target.Length); + + Assert.Equal(0, score); + } + + [Fact] + public void GetMaxScoreDP_SubsequenceMatch_SkipsNoise() + { + // result 中有噪声,但子序列 [1,2,3] 可匹配 + (int, float)[] result = [(9, 0.5f), (1, 0.8f), (9, 0.3f), (2, 0.9f), (3, 0.7f)]; + int[] target = [1, 2, 3]; + + var score = OcrUtils.GetMaxScoreDp(result, target, target.Length); + + // (0.8 + 0.9 + 0.7) / 3 = 0.8 + Assert.Equal(0.8, score, 0.01); + } + + [Fact] + public void GetMaxScoreDP_PicksBestConfidence() + { + // target [1],result 中有两个 index=1,应选置信度最高的 + (int, float)[] result = [(1, 0.3f), (1, 0.9f)]; + int[] target = [1]; + + var score = OcrUtils.GetMaxScoreDp(result, target, 1); + + Assert.Equal(0.9, score, 0.01); + } + + [Fact] + public void GetMaxScoreDP_NormalizesWithAvailableCount() + { + // availableCount > target.Length 时分数被稀释 + (int, float)[] result = [(1, 1.0f), (2, 1.0f)]; + int[] target = [1, 2]; + + var score = OcrUtils.GetMaxScoreDp(result, target, 4); + + // (1.0 + 1.0) / 4 = 0.5 + Assert.Equal(0.5, score, 0.01); + } + + [Fact] + public void GetMaxScoreDP_ManyFrames_TargetLengthDenominator_ScoresHigh() + { + // 模拟多个文字区域的字符帧合并后做匹配,分母应为 target.Length + // 即使有很多噪声帧,只要 target 完整匹配,分数仍应很高 + (int, float)[] result = [ + (9, 0.5f), (8, 0.6f), (7, 0.4f), // 噪声区域1 + (1, 0.9f), (2, 0.85f), // 
匹配目标 [1,2] + (6, 0.7f), (5, 0.3f), (4, 0.5f), // 噪声区域2 + (9, 0.2f), (8, 0.4f) // 噪声区域3 + ]; + int[] target = [1, 2]; + + // 使用 target.Length 作为分母:(0.9 + 0.85) / 2 = 0.875 + var score = OcrUtils.GetMaxScoreDp(result, target, target.Length); + + Assert.Equal(0.875, score, 0.01); + } + + #endregion + + #region CreateWeights + + [Fact] + public void CreateWeights_DefaultsToOne() + { + IReadOnlyList labels = ["a", "b", "c"]; + var labelDict = OcrUtils.CreateLabelDict(labels, out _); + var weights = OcrUtils.CreateWeights(new Dictionary(), labelDict, labels.Count); + + // labels.Count + 2 = 5 + Assert.Equal(5, weights.Length); + Assert.All(weights, w => Assert.Equal(1.0f, w)); + } + + [Fact] + public void CreateWeights_AppliesExtraWeights() + { + IReadOnlyList labels = ["a", "b", "c"]; + var extra = new Dictionary { { "b", 2.5f } }; + var labelDict = OcrUtils.CreateLabelDict(labels, out _); + + var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count); + + // "b" 是 labels[1],index=2 + Assert.Equal(1.0f, weights[1]); // "a" + Assert.Equal(2.5f, weights[2]); // "b" + Assert.Equal(1.0f, weights[3]); // "c" + } + + [Fact] + public void CreateWeights_IgnoresUnknownKeys() + { + IReadOnlyList labels = ["a", "b"]; + var extra = new Dictionary { { "z", 5.0f } }; + var labelDict = OcrUtils.CreateLabelDict(labels, out _); + + var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count); + + Assert.All(weights, w => Assert.Equal(1.0f, w)); + } + + [Fact] + public void CreateWeights_SpaceKey_MapsToCorrectIndex() + { + // 空格权重应写入 labels.Count + 1 位置,与 CreateLabelDict 一致 + IReadOnlyList labels = ["a", " ", "b"]; + var extra = new Dictionary { { " ", 3.0f } }; + var labelDict = OcrUtils.CreateLabelDict(labels, out _); + + var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count); + + // labels.Count + 1 = 4,空格权重应在 weights[4] + Assert.Equal(3.0f, weights[labels.Count + 1]); + // labels 中 " " 的位置 index=2(即 weights[2])不应被错误写入 + Assert.Equal(1.0f, weights[2]); 
+    }
+
+    #endregion
+}
diff --git a/Test/BetterGenshinImpact.UnitTest/HelperTests/LruCacheTests.cs b/Test/BetterGenshinImpact.UnitTest/HelperTests/LruCacheTests.cs
new file mode 100644
index 00000000..2bff491f
--- /dev/null
+++ b/Test/BetterGenshinImpact.UnitTest/HelperTests/LruCacheTests.cs
@@ -0,0 +1,100 @@
+using BetterGenshinImpact.Helpers;
+
+namespace BetterGenshinImpact.UnitTest.HelperTests;
+
+/// <summary>
+/// Unit tests for <c>CacheHelper.LruCache</c>: set/get, capacity eviction, recency refresh on read,
+/// time-based expiry, removal and clearing.
+/// </summary>
+public class LruCacheTests
+{
+    /// <summary>Values stored under distinct keys can be read back while capacity is not exceeded.</summary>
+    [Fact]
+    public void BasicSetGetTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(3);
+        cache.Set("a", "1");
+        cache.Set("b", "2");
+        cache.Set("c", "3");
+
+        // Assert the hit and the value separately so a failure reports which of the two was wrong.
+        Assert.True(cache.TryGet("a", out var v1));
+        Assert.Equal("1", v1);
+        Assert.True(cache.TryGet("b", out var v2));
+        Assert.Equal("2", v2);
+        Assert.True(cache.TryGet("c", out var v3));
+        Assert.Equal("3", v3);
+    }
+
+    /// <summary>Adding a third entry to a capacity-2 cache evicts the entry that was set first.</summary>
+    [Fact]
+    public void LruEvictionTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(2);
+        cache.Set("a", "1");
+        cache.Set("b", "2");
+        cache.Set("c", "3"); // "a" should be evicted
+
+        Assert.False(cache.TryGet("a", out _));
+        Assert.True(cache.TryGet("b", out var v2));
+        Assert.Equal("2", v2);
+        Assert.True(cache.TryGet("c", out var v3));
+        Assert.Equal("3", v3);
+    }
+
+    /// <summary>Reading an entry refreshes its recency, so the next insert evicts the other entry.</summary>
+    [Fact]
+    public void UpdateMovesToHeadTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(2);
+        cache.Set("a", "1");
+        cache.Set("b", "2");
+        cache.TryGet("a", out _); // "a" becomes the most recently used entry
+        cache.Set("c", "3"); // "b" should be evicted
+
+        Assert.True(cache.TryGet("a", out _));
+        Assert.False(cache.TryGet("b", out _));
+        Assert.True(cache.TryGet("c", out _));
+    }
+
+    /// <summary>An entry in a cache created with a 500 ms time span is no longer returned after 650 ms.</summary>
+    [Fact]
+    public void ExpireTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(2, TimeSpan.FromMilliseconds(500));
+        cache.Set("a", "1");
+        Assert.True(cache.TryGet("a", out var v));
+        Assert.Equal("1", v);
+
+        // Sleep past the 500 ms window; the extra 150 ms absorbs scheduler jitter.
+        Thread.Sleep(650);
+
+        Assert.False(cache.TryGet("a", out _));
+    }
+
+    /// <summary>Remove reports success for an existing key, after which the key is no longer retrievable.</summary>
+    [Fact]
+    public void RemoveTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(2);
+        cache.Set("a", "1");
+
+        Assert.True(cache.Remove("a"));
+        Assert.False(cache.TryGet("a", out _));
+    }
+
+    /// <summary>Count follows the number of stored entries and returns to zero after Clear.</summary>
+    [Fact]
+    public void ClearTest()
+    {
+        var cache = new CacheHelper.LruCache<string, string>(2);
+        Assert.Equal(0, cache.Count);
+
+        cache.Set("a", "1");
+        cache.Set("b", "2");
+        Assert.Equal(2, cache.Count);
+
+        cache.Clear();
+        Assert.Equal(0, cache.Count);
+    }
+}