Fuzzy matching for text recognition (#2799)

* chore: add AGENTS.md to .gitignore

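* feat(config): add the AllowDuplicateChar OCR config option

* refactor(ocr): expose protected members of Rec, extract RunInference, support AllowDuplicateChar

* feat(ocr): thread the AllowDuplicateChar parameter through PaddleOcrService → Rec

* feat(ocr): add CreateLabelDict/CreateWeights utility methods to OcrUtils

* feat(helpers): add the LruCache caching utility class

* feat(ocr): add RecMatch, a DP-based fuzzy-matching recognizer

* test(helpers): add LruCache unit tests

* test(ocr): add unit tests for RecMatch.GetTarget / CreateLabelDict

* fix(ocr): fix how weight matrix multiplication is used in RecMatch

* refactor(ocr): merge RecMatch into Rec, extract testable static methods, add unit tests

Merge the RecMatch subclass into Rec, removing the inheritance relationship and the duplicated batching logic (extracted as RunBatch<T>).
Extract the core GetTarget logic and GetMaxScoreDP into static OcrUtils methods so they can be tested in isolation.
Rename the test file and add 16 unit tests covering MapStringToLabelIndices, GetMaxScoreDP, and CreateWeights.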

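* feat(ocr): expose Rec.RunMatch to the JS engine and internal C# code

Add the IOcrMatchService interface, which provides the DP-based fuzzy-matching methods OcrMatch/OcrMatchDirect
and returns a confidence score from 0 to 1. PaddleOcrService implements the interface, and OcrFactory.PaddleMatch
guarantees a non-null return value (if the engine lacks support, it automatically falls back to plain OCR plus an edit-distance string comparison).
BvPage gains OcrMatch/WaitForOcrMatch for use by JS scripts; the threshold is adjustable through configuration.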
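
The DP fuzzy match can be pictured with a minimal sketch. It assumes the OCR output has already been reduced to (label index, confidence) pairs. The class and method names are illustrative, not the ones used in the repository. The score is the best in-order match of the target labels, averaged over the target length, and 0 when the target cannot be matched completely.

```csharp
using System;

public static class DpMatchSketch
{
    // Best-scoring in-order match of target inside result, averaged per target label.
    public static double MaxScore((int Label, float Conf)[] result, int[] target)
    {
        if (target.Length == 0) return 0;

        // dp[j] = best summed confidence for matching the first j target labels.
        var dp = new double[target.Length + 1];
        Array.Fill(dp, double.NegativeInfinity); // unreachable
        dp[0] = 0;

        foreach (var (label, conf) in result)
        {
            // Walk j downward so one result element fills at most one target slot.
            for (var j = target.Length; j >= 1; j--)
            {
                if (label != target[j - 1]) continue;
                if (double.IsNegativeInfinity(dp[j - 1])) continue; // prefix not matched yet
                dp[j] = Math.Max(dp[j], dp[j - 1] + conf);
            }
        }

        return double.IsNegativeInfinity(dp[target.Length])
            ? 0
            : dp[target.Length] / target.Length;
    }
}
```

For example, matching target [5, 6, 7] against (5, 0.9), (7, 0.2), (6, 0.8), (7, 0.7) cannot use the early 7, because label 6 has not been matched at that point, so the result averages 0.9, 0.8 and 0.7 to roughly 0.8.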

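* feat(ui): add OCR settings for allowing duplicate characters and for the fuzzy-match threshold

Add two controls to the OCR configuration area of the general settings page:
- a toggle for allowing consecutive duplicate characters (AllowDuplicateChar)
- an input box for the OCR fuzzy-match threshold (OcrMatchDefaultThreshold)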

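* fix: fix several issues found during code review of PR #2799

- Fix the NaN produced in Rec.cs by the score/sb.Length division by zero when the text is empty
- Fix the same object being disposed twice in BvPage.cs when rect == default
- Remove the finalizer in Rec.cs to avoid a deadlock from taking a lock on the GC thread
- Remove the ineffective WeakKey feature from CacheHelper and simplify it to a direct Dictionary lookup
- Add a check that the weights array length matches the model output dimension
- Fix the label index collision for the space character in CreateLabelDict
- Fix the division by zero in GetMaxScoreDP when availableCount = 0
- Fix the case sensitivity of Contains in OcrMatchFallbackService
- Fix the division by zero in BvPage.cs when DefaultRetryInterval = 0
- Add a [0,1] range constraint to OcrMatchDefaultThreshold
- Extract a BGRA→BGR conversion helper in PaddleOcrService
- Use Interlocked.CompareExchange to make the OcrFactory fallback thread-safe
- Increase the TTL margin in LruCacheTests BuilderTest to avoid CI flakiness
- Update the .gitignore comments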
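
The thread-safety item in this list relies on a lock-free "publish once" idiom. Below is a minimal sketch of that idiom, assuming a hypothetical `LazySlot<T>` holder (not a type from the repository): several threads may each build a candidate, but only the first one stored is ever returned.

```csharp
using System;
using System.Threading;

// Hypothetical holder illustrating lazy publication with Interlocked.CompareExchange.
public sealed class LazySlot<T> where T : class
{
    private T? _value;

    public T GetOrPublish(Func<T> create)
    {
        var existing = Volatile.Read(ref _value);
        if (existing is not null) return existing;

        var candidate = create();
        // CompareExchange stores candidate only if the slot is still null and
        // returns the previous value: null means our candidate won the race.
        return Interlocked.CompareExchange(ref _value, candidate, null) ?? candidate;
    }
}
```

The trade-off is that a losing thread's candidate is discarded, so this fits cheap, side-effect-free objects better than expensive ones.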

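* fix: fix the OcrMatch normalization denominator bug that made multi-region match scores too low; improve the UI

- Fix GetMaxScoreFlat using availableCount, the number of non-empty images, as the denominator,
  which over-diluted the match score when there are multiple text regions; use target.Length instead
- Add a "requires reloading the OCR engine" hint to the AllowDuplicateChar setting
- Change the OCR fuzzy-match threshold control from a TextBox to a Slider with a numeric display
- Remove the problematic finalizer in the Det class (a destructor that takes a lock can deadlock)
- Add unit tests for multi-region scenarios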
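
A small worked example shows why the denominator matters. This is a sketch under assumptions: the helper name is hypothetical, and the "old" behaviour is modelled simply as dividing by the number of non-empty regions, as the first bullet describes.

```csharp
public static class NormalizationSketch
{
    // Average confidence = summed confidence of the matched characters / denominator.
    public static double Normalize(double summedConfidence, int denominator)
        => denominator > 0 ? summedConfidence / denominator : 0;
}
```

Take a two-character target matched at confidences 0.9 and 0.9 (sum 1.8) on a screen with six non-empty text regions. Dividing by the region count gives 1.8 / 6 = 0.3, which would fail a threshold such as 0.8, while dividing by target.Length gives 1.8 / 2 = 0.9.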

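* feat(ocr): add an option and related configuration for using OcrMatch fuzzy matching when switching parties

* fix(ui): update the default match-success threshold to 0.8

* fix(ocr): fix null handling in the party-switching logic and tidy the code structure

* refactor: simplify LruCache by removing weak-reference support and the Builder pattern

- Remove WeakReference support, which had a TOCTOU bug (and no actual users)
- Change the CacheItem class to a ValueTuple to reduce heap allocations
- Stop assigning DateTime.MaxValue when there is no expiry; the expiry check is short-circuited instead
- Remove LruCacheBuilder, which was down to two parameters, and use the constructor directly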
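
To make the simplified design concrete, here is a minimal LRU cache sketch along the lines of these bullets: value-tuple entries, a plain constructor, and an expiry check that is skipped when no TTL is set. It is an illustration only. `LruSketch` and its members are hypothetical, it is not thread-safe, and it is not the repository's LruCache.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical minimal LRU cache; not thread-safe.
public sealed class LruSketch<TKey, TValue> where TKey : notnull
{
    private readonly int _capacity;
    private readonly TimeSpan? _ttl;
    private readonly LinkedList<(TKey Key, TValue Value, DateTime Expiry)> _order = new();
    private readonly Dictionary<TKey, LinkedListNode<(TKey Key, TValue Value, DateTime Expiry)>> _map = new();

    public LruSketch(int capacity, TimeSpan? ttl = null)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
        _ttl = ttl;
    }

    public bool TryGet(TKey key, out TValue? value)
    {
        if (_map.TryGetValue(key, out var node))
        {
            // With no TTL the expiry comparison is skipped entirely.
            if (_ttl is null || node.Value.Expiry > DateTime.UtcNow)
            {
                _order.Remove(node);
                _order.AddFirst(node); // most recently used moves to the front
                value = node.Value.Value;
                return true;
            }
            _order.Remove(node); // expired entry
            _map.Remove(key);
        }
        value = default;
        return false;
    }

    public void Set(TKey key, TValue value)
    {
        if (_map.Remove(key, out var existing))
        {
            _order.Remove(existing); // overwrite: drop the old node
        }
        else if (_map.Count >= _capacity)
        {
            var oldest = _order.Last!; // least recently used sits at the back
            _map.Remove(oldest.Value.Key);
            _order.RemoveLast();
        }
        var expiry = _ttl is { } ttl ? DateTime.UtcNow + ttl : default;
        _map[key] = _order.AddFirst((key, value, expiry));
    }
}
```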

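* fix(ocr): fix the bug in CreateWeights where the space character's weight was written to the wrong index

Reuse CreateLabelDict to build the index mapping, so that the space maps to labels.Count+1,
consistent with CreateLabelDict. Add a corresponding test case.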
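
The index scheme behind this fix can be condensed into a short sketch: label i maps to index i + 1, index 0 is reserved for the CTC blank, the space maps to labels.Count + 1, and the weight array therefore has labels.Count + 2 slots. The class and method names below are illustrative stand-ins rather than the repository's code.

```csharp
using System;
using System.Collections.Generic;

public static class LabelIndexSketch
{
    public static Dictionary<string, int> BuildDict(IReadOnlyList<string> labels)
    {
        var dict = new Dictionary<string, int>();
        for (var i = 0; i < labels.Count; i++)
        {
            if (labels[i] == " ") continue; // the space gets its own slot below
            dict[labels[i]] = i + 1;        // index 0 is the CTC blank
        }
        dict[" "] = labels.Count + 1;
        return dict;
    }

    public static float[] BuildWeights(
        IReadOnlyDictionary<string, float> extra,
        IReadOnlyDictionary<string, int> dict,
        int labelCount)
    {
        var weights = new float[labelCount + 2]; // blank + labels + space
        Array.Fill(weights, 1.0f);
        foreach (var (key, value) in extra)
        {
            // Going through the same dictionary keeps both methods in agreement.
            if (dict.TryGetValue(key, out var index) && index >= 0 && index < weights.Length)
                weights[index] = value;
        }
        return weights;
    }
}
```

With labels ["a", "b"], the space lands at index 3, so a weight of 0.5 for " " is written to weights[3] and every other slot stays at 1.0.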

* fix(ocr): fix Free in the finally block masking the original exception when GCHandle.Alloc fails
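
The pattern behind this fix, sketched in isolation: declare the handle before the try block and free it only if it was actually allocated. `PinSketch.ReadFirst` is a hypothetical helper written for this illustration.

```csharp
using System;
using System.Runtime.InteropServices;

public static class PinSketch
{
    // Pins the array, reads its first element through the pinned address, then unpins.
    public static float ReadFirst(float[] data)
    {
        if (data.Length == 0) throw new ArgumentException("data must not be empty", nameof(data));

        GCHandle handle = default;
        try
        {
            // If Alloc throws, handle stays default and IsAllocated remains false.
            handle = GCHandle.Alloc(data, GCHandleType.Pinned);
            var first = new float[1];
            Marshal.Copy(handle.AddrOfPinnedObject(), first, 0, 1);
            return first[0];
        }
        finally
        {
            // Free on a handle that was never allocated throws InvalidOperationException,
            // which would replace the original exception from the try block.
            if (handle.IsAllocated) handle.Free();
        }
    }
}
```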

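* fix(ocr): add an existence check for the party selection button to avoid PartySetupFailedException

* fix(ocr): set the TickFrequency of OcrMatchDefaultThreshold to 0.01

* fix(ocr): fix the region cropping logic so the crop size is never negative

* fix(ocr): optimize character confidence extraction by looking up the confidence directly at the target character's index

* fix(ocr): fix variable naming for consistency and adjust method name casing

* fix(ocr): change CreateWeights to take the label dictionary and label count, streamlining weight creation

* fix(ocr): update the OCR confidence threshold settings so the threshold stays within 0.01 to 0.99, and tidy the related logic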
This commit is contained in:
Takaranoao
2026-02-19 23:08:46 -08:00
committed by GitHub
parent 20fe152630
commit e9d11f7267
16 changed files with 1702 additions and 296 deletions

2
.gitignore vendored
View File

@@ -28,7 +28,7 @@ github_actions_cache/
*.zip
# Rider
# IDE & AI tools
.idea
.trae
.claude

View File

@@ -2,7 +2,9 @@ using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using BetterGenshinImpact.Core.Recognition;
using BetterGenshinImpact.Core.Recognition.OCR;
using BetterGenshinImpact.Core.Simulator;
using BetterGenshinImpact.GameTask;
using BetterGenshinImpact.GameTask.Common;
using BetterGenshinImpact.GameTask.Model.Area;
using Fischless.WindowsInput;
@@ -118,4 +120,56 @@ public class BvPage
{
GameCaptureRegion.GameRegion1080PPosClick(x, y);
}
}
/// <summary>
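/// Uses fuzzy matching to decide whether the screenshot contains the target text.
/// <see cref="OcrFactory.PaddleMatch"/> automatically selects the best implementation (DP fuzzy matching, or plain OCR plus string comparison).
/// </summary>
/// <param name="target">Target string</param>
/// <param name="rect">Region of interest; default means the full screen</param>
/// <param name="threshold">Match threshold (0 to 1); null uses the default threshold from the configuration</param>
/// <returns>Whether the match succeeded</returns>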
public bool OcrMatch(string target, Rect rect = default, double? threshold = null)
{
var matchService = OcrFactory.PaddleMatch;
var actualThreshold = threshold
?? TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold;
var screen = TaskControl.CaptureToRectArea();
try
{
var roi = rect == default ? null : screen.DeriveCrop(rect);
try
{
var score = matchService.OcrMatch((roi ?? screen).SrcMat, target);
return score >= actualThreshold;
}
finally
{
roi?.Dispose();
}
}
finally
{
screen.Dispose();
}
}
/// <summary>
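/// Repeatedly captures screenshots and applies fuzzy matching, waiting for the target text to appear.
/// Returns false on timeout instead of throwing an exception.
/// </summary>
/// <param name="target">Target string</param>
/// <param name="rect">Region of interest; default means the full screen</param>
/// <param name="threshold">Match threshold (0 to 1); null uses the default threshold from the configuration</param>
/// <param name="timeout">Timeout in milliseconds; null uses DefaultTimeout</param>
/// <returns>Whether the match succeeded before the timeout</returns>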
public async Task<bool> WaitForOcrMatch(string target, Rect rect = default, double? threshold = null, int? timeout = null)
{
var actualTimeout = timeout ?? DefaultTimeout;
var retryCount = DefaultRetryInterval > 0 ? actualTimeout / DefaultRetryInterval : 1;
return await NewRetry.WaitForAction(() => OcrMatch(target, rect, threshold),
_cancellationToken, retryCount, DefaultRetryInterval);
}
}

View File

@@ -106,6 +106,46 @@ public partial class OtherConfig : ObservableObject
/// </summary>
[ObservableProperty]
private PaddleOcrModelConfig _paddleOcrModelConfig = PaddleOcrModelConfig.V4Auto;
/// <summary>
/// Allow consecutive duplicate characters in OCR results (disables CTC duplicate-character collapsing)
/// </summary>
[ObservableProperty]
private bool _allowDuplicateChar;
/// <summary>
/// Use OcrMatch fuzzy matching instead of regular-expression matching when switching parties
/// </summary>
[ObservableProperty]
private bool _useOcrMatchForPartySwitch = true;
/// <summary>
/// Default threshold for OcrMatch fuzzy matching (0 to 1); a score ≥ the threshold counts as a successful match
/// </summary>
[ObservableProperty]
private double _ocrMatchDefaultThreshold = 0.8;
partial void OnOcrMatchDefaultThresholdChanged(double value)
{
if (value is <= 0 or > 1)
{
OcrMatchDefaultThreshold = Math.Clamp(value, 0.01, 1);
}
}
/// <summary>
/// PaddleOCR recognition confidence threshold (0 to 1); characters below this threshold are filtered out
/// </summary>
[ObservableProperty]
private double _paddleOcrThreshold = 0.5;
partial void OnPaddleOcrThresholdChanged(double value)
{
if (value is < 0 or >= 1)
{
PaddleOcrThreshold = Math.Clamp(value, 0, 0.99);
}
}
}
//public partial class OtherConfig : ObservableObject

View File

@@ -16,9 +16,9 @@ public static class OcrUtils
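/// Preprocessing is more than 5x faster than the unsafe version and uses fewer resources
/// </summary>
/// <param name="inputImage">Input image; converted to grayscale if it is not already</param>
/// <param name="tensorMemoryOwnser">Memory owner of the tensor; must be disposed after use</param>
/// <param name="tensorMemoryOwner">Memory owner of the tensor; must be disposed after use</param>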
/// <returns></returns>
public static Tensor<float> ToTensorYapDnn(Mat inputImage, out IMemoryOwner<float> tensorMemoryOwnser)
public static Tensor<float> ToTensorYapDnn(Mat inputImage, out IMemoryOwner<float> tensorMemoryOwner)
{
using var rt = new ResourcesTracker();
Mat dst;
@@ -40,10 +40,10 @@ public static class OcrUtils
// Use vectorized operations instead of loops
var blob = rt.T(CvDnn.BlobFromImage(padded, 1.0 / 255.0, default, default, false, false));
var nCols = padded.Cols * padded.Rows;
tensorMemoryOwnser = MemoryPool<float>.Shared.Rent(nCols);
tensorMemoryOwner = MemoryPool<float>.Shared.Rent(nCols);
// Copy the memory; building the tensor directly from a pointer is slower than making an extra copy
blob.AsSpan<float>().CopyTo(tensorMemoryOwnser.Memory.Span);
return new DenseTensor<float>(tensorMemoryOwnser.Memory[..nCols], [1, 1, 32, 384]);
blob.AsSpan<float>().CopyTo(tensorMemoryOwner.Memory.Span);
return new DenseTensor<float>(tensorMemoryOwner.Memory[..nCols], [1, 1, 32, 384]);
}
/// <summary>
@@ -180,6 +180,119 @@ public static class OcrUtils
};
}
/// <summary>
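/// Builds a string→index dictionary from the label list, for use by Rec fuzzy matching.
/// Indices start at 1; 0 is the CTC blank, and the space character maps to labels.Count+1.
/// </summary>
/// <param name="labels">Label list of the recognition model</param>
/// <param name="labelLengths">Set of label character lengths (in descending order, used for greedy longest-first matching)</param>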
public static IReadOnlyDictionary<string, int> CreateLabelDict(
IReadOnlyList<string> labels, out int[] labelLengths)
{
var dict = new Dictionary<string, int>();
var lengths = new HashSet<int>();
for (var i = 0; i < labels.Count; i++)
{
if (labels[i] == " ") continue;
var len = labels[i].Length;
if (len > 0) lengths.Add(len);
dict[labels[i]] = i + 1;
}
// The space character maps to index labels.Count + 1
dict[" "] = labels.Count + 1;
lengths.Add(1);
// Descending order: try longer labels first
labelLengths = lengths.OrderByDescending(x => x).ToArray();
return dict;
}
/// <summary>
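/// Creates a weight array of length labelCount + 2, indexed by label index, from the extra-weights dictionary (used to weight inference scores).
/// Labels without a specified weight default to 1.0.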
/// </summary>
public static float[] CreateWeights(
Dictionary<string, float> extraWeights, IReadOnlyDictionary<string, int> labelDict, int labelCount)
{
var result = new float[labelCount + 2];
Array.Fill(result, 1.0f);
foreach (var (key, value) in extraWeights)
{
if (!labelDict.TryGetValue(key, out var index)) continue;
if (index >= 0 && index < result.Length)
{
result[index] = value;
}
}
return result;
}
/// <summary>
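/// Maps the target string to a sequence of label indices.
/// Uses greedy longest-first matching; characters that cannot be mapped are skipped.
/// </summary>
/// <param name="target">Target string</param>
/// <param name="labelDict">Label→index dictionary (produced by CreateLabelDict)</param>
/// <param name="labelLengths">Set of label lengths in descending order (produced by CreateLabelDict)</param>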
public static int[] MapStringToLabelIndices(
string target,
IReadOnlyDictionary<string, int> labelDict,
int[] labelLengths)
{
var chars = target.ToCharArray();
var targetIndices = new int[chars.Length];
Array.Fill(targetIndices, -1);
var index = 0;
while (index < chars.Length)
{
var found = false;
foreach (var labelLength in labelLengths)
{
if (index + labelLength > chars.Length) continue;
var subStr = new string(chars, index, labelLength);
if (!labelDict.TryGetValue(subStr, out var labelIndex)) continue;
targetIndices[index] = labelIndex;
index += labelLength;
found = true;
break;
}
if (!found) index++;
}
return targetIndices.Where(x => x != -1).ToArray();
}
/// <summary>
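/// Maximum-score subsequence matching via dynamic programming.
/// Finds the highest-confidence subsequence match of target within the result sequence and returns a normalized score (0 to 1).
/// </summary>
/// <param name="result">Sequence of (labelIndex, confidence) pairs output by OCR</param>
/// <param name="target">Target label index sequence</param>
/// <param name="availableCount">Normalization denominator (usually target.Length, which yields the average confidence per target character)</param>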
public static double GetMaxScoreDp((int, float)[] result, int[] target, int availableCount)
{
if (target.Length == 0 || availableCount <= 0) return 0;
var dp = new double[target.Length + 1];
dp[0] = 0;
for (var j = 1; j <= target.Length; j++)
dp[j] = -255d; // unreachable
foreach (var (index, confidence) in result)
{
// Update in reverse order so the same result element is not used more than once
for (var j = target.Length; j >= 1; j--)
{
if (index != target[j - 1]) continue;
if (!(dp[j - 1] > -200)) continue; // predecessor unreachable
var newSum = dp[j - 1] + confidence;
if (newSum > dp[j]) dp[j] = newSum;
}
}
if (dp[target.Length] <= -200) return 0; // no complete match
return dp[target.Length] / availableCount;
}
public static Mat Tensor2Mat(Tensor<float> tensor)
{
var dimensions = tensor.Dimensions;

View File

@@ -0,0 +1,26 @@
using OpenCvSharp;
namespace BetterGenshinImpact.Core.Recognition.OCR;
/// <summary>
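/// OCR service interface based on DP fuzzy matching; returns a match confidence score (0 to 1).
/// Independent of IOcrService; implemented only by engines that support fuzzy matching.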
/// </summary>
public interface IOcrMatchService
{
/// <summary>
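/// Locates text regions with the detector, fuzzy-matches each region, and returns the highest confidence (0 to 1).
/// </summary>
/// <param name="mat">Input image (three-channel BGR recommended)</param>
/// <param name="target">Target string</param>
/// <returns>Match confidence: 0 means no match at all, 1 means a perfect match</returns>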
double OcrMatch(Mat mat, string target);
/// <summary>
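/// Fuzzy-matches the whole image directly, without the detector, and returns the confidence (0 to 1).
/// </summary>
/// <param name="mat">Input image (three-channel BGR recommended)</param>
/// <param name="target">Target string</param>
/// <returns>Match confidence: 0 means no match at all, 1 means a perfect match</returns>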
double OcrMatchDirect(Mat mat, string target);
}

View File

@@ -1,5 +1,6 @@
using System;
using System.Globalization;
using System.Threading;
using System.Threading.Tasks;
using BetterGenshinImpact.Core.Config;
using BetterGenshinImpact.Core.Recognition.OCR.Paddle;
@@ -18,7 +19,27 @@ public class OcrFactory : IDisposable
public static IOcrService Paddle => App.ServiceProvider.GetRequiredService<OcrFactory>().PaddleOcr;
private IOcrService PaddleOcr => _paddleOcrService ??= Create(OcrEngineTypes.Paddle);
/// <summary>
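/// Gets an OCR service that supports fuzzy matching.
/// If the engine natively implements IOcrMatchService it is returned directly; otherwise this falls back to plain OCR plus string similarity.
/// Accessing this property triggers lazy loading of the Paddle engine.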
/// </summary>
public static IOcrMatchService PaddleMatch
{
get
{
var factory = App.ServiceProvider.GetRequiredService<OcrFactory>();
var service = factory.PaddleOcr;
if (service is IOcrMatchService matchService)
return matchService;
var fallback = new OcrMatchFallbackService(service);
return Interlocked.CompareExchange(ref factory._paddleOcrMatchFallback, fallback, null)
?? fallback;
}
}
private IOcrService? _paddleOcrService;
private IOcrMatchService? _paddleOcrMatchFallback;
private readonly ILogger<BgiOnnxFactory> _logger;
private readonly OtherConfig.Ocr _config;
@@ -87,34 +108,33 @@ public class OcrFactory : IDisposable
private PaddleOcrService CreatePaddleOcrInstance()
{
var allowDuplicateChar = _config.AllowDuplicateChar;
var threshold = (float)_config.PaddleOcrThreshold;
var factory = App.ServiceProvider.GetRequiredService<BgiOnnxFactory>();
return _config.PaddleOcrModelConfig switch
{
PaddleOcrModelConfig.V4Auto =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
new PaddleOcrService(factory,
PaddleOcrService.PaddleOcrModelType.FromCultureInfoV4(GetCultureInfo()) ??
PaddleOcrService.PaddleOcrModelType.V4),
PaddleOcrService.PaddleOcrModelType.V4,
allowDuplicateChar, threshold),
PaddleOcrModelConfig.V5Auto =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
new PaddleOcrService(factory,
PaddleOcrService.PaddleOcrModelType.FromCultureInfo(GetCultureInfo()) ??
PaddleOcrService.PaddleOcrModelType.V5),
PaddleOcrService.PaddleOcrModelType.V5,
allowDuplicateChar, threshold),
PaddleOcrModelConfig.V5 =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V5),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5, allowDuplicateChar, threshold),
PaddleOcrModelConfig.V4 =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V4),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V4, allowDuplicateChar, threshold),
PaddleOcrModelConfig.V4En =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V4En),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V4En, allowDuplicateChar, threshold),
PaddleOcrModelConfig.V5Korean =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V5Korean),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Korean, allowDuplicateChar, threshold),
PaddleOcrModelConfig.V5Latin =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V5Latin),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Latin, allowDuplicateChar, threshold),
PaddleOcrModelConfig.V5Eslav =>
new PaddleOcrService(App.ServiceProvider.GetRequiredService<BgiOnnxFactory>(),
PaddleOcrService.PaddleOcrModelType.V5Eslav),
new PaddleOcrService(factory, PaddleOcrService.PaddleOcrModelType.V5Eslav, allowDuplicateChar, threshold),
_ => throw new ArgumentOutOfRangeException(nameof(_config.PaddleOcrModelConfig),
_config.PaddleOcrModelConfig, "不支持的 Paddle OCR 模型配置")
};
@@ -123,6 +143,7 @@ public class OcrFactory : IDisposable
public Task Unload()
{
_paddleOcrMatchFallback = null;
if (_paddleOcrService is not IDisposable disposable)
{
_paddleOcrService = null;

View File

@@ -0,0 +1,99 @@
using System;
using System.Diagnostics;
using OpenCvSharp;
namespace BetterGenshinImpact.Core.Recognition.OCR;
/// <summary>
/// Fallback implementation for OCR engines that do not support IOcrMatchService.
/// Recognizes text with plain OCR, then matches by string similarity.
/// </summary>
public class OcrMatchFallbackService : IOcrMatchService
{
private readonly IOcrService _ocrService;
public OcrMatchFallbackService(IOcrService ocrService)
{
_ocrService = ocrService;
}
public double OcrMatch(Mat mat, string target)
{
var startTime = Stopwatch.GetTimestamp();
var ocrResult = _ocrService.OcrResult(mat);
var score = ComputeBestTextSimilarity(ocrResult, target);
var time = Stopwatch.GetElapsedTime(startTime);
Debug.WriteLine($"OcrMatchFallback 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}");
return score;
}
public double OcrMatchDirect(Mat mat, string target)
{
var startTime = Stopwatch.GetTimestamp();
var text = _ocrService.OcrWithoutDetector(mat);
var score = ComputeTextSimilarity(text, target);
var time = Stopwatch.GetElapsedTime(startTime);
Debug.WriteLine($"OcrMatchDirectFallback 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}");
return score;
}
/// <summary>
/// Finds the best similarity score against the target string across all regions of the OCR result.
/// </summary>
private static double ComputeBestTextSimilarity(OcrResult ocrResult, string target)
{
double bestScore = 0;
foreach (var region in ocrResult.Regions)
{
var score = ComputeTextSimilarity(region.Text, target);
if (score > bestScore) bestScore = score;
if (score >= 1.0) break;
}
return bestScore;
}
/// <summary>
/// Computes the similarity of two strings (0 to 1).
/// Checks substring containment first; otherwise uses the edit distance.
/// </summary>
public static double ComputeTextSimilarity(string text, string target)
{
if (string.IsNullOrEmpty(target)) return 1.0;
if (string.IsNullOrEmpty(text)) return 0.0;
if (text.Contains(target, StringComparison.OrdinalIgnoreCase)) return 1.0;
if (target.Contains(text, StringComparison.OrdinalIgnoreCase)) return (double)text.Length / target.Length;
var distance = LevenshteinDistance(text, target);
var maxLen = Math.Max(text.Length, target.Length);
return 1.0 - (double)distance / maxLen;
}
/// <summary>
/// Computes the edit distance (Levenshtein distance) between two strings.
/// </summary>
public static int LevenshteinDistance(string s, string t)
{
var sLen = s.Length;
var tLen = t.Length;
var prev = new int[tLen + 1];
var curr = new int[tLen + 1];
for (var j = 0; j <= tLen; j++)
prev[j] = j;
for (var i = 1; i <= sLen; i++)
{
curr[0] = i;
for (var j = 1; j <= tLen; j++)
{
var cost = s[i - 1] == t[j - 1] ? 0 : 1;
curr[j] = Math.Min(Math.Min(curr[j - 1] + 1, prev[j] + 1), prev[j - 1] + cost);
}
(prev, curr) = (curr, prev);
}
return prev[tLen];
}
}

View File

@@ -30,15 +30,7 @@ public class Det(BgiOnnxModel model, OcrVersionConfig config, BgiOnnxFactory bgi
/// <summary>Gets or sets the ratio for enlarging text boxes during post-processing.</summary>
public float UnclipRatio { get; set; } = 2.0f;
~Det()
{
lock (_session)
{
_session.Dispose();
}
}
public void Dispose()
{
lock (_session)

View File

@@ -15,7 +15,7 @@ using Size = OpenCvSharp.Size;
namespace BetterGenshinImpact.Core.Recognition.OCR.Paddle;
public class PaddleOcrService : IOcrService, IDisposable
public class PaddleOcrService : IOcrService, IOcrMatchService, IDisposable
{
/// <summary>
/// Usage:
@@ -103,11 +103,11 @@ public class PaddleOcrService : IOcrService, IDisposable
TestImagePath);
}
public (Det, Rec) Build(BgiOnnxFactory onnxFactory)
public (Det, Rec) Build(BgiOnnxFactory onnxFactory, bool allowDuplicateChar = false, float threshold = 0.5f)
{
return (
new Det(DetectionModel, DetectionVersion, onnxFactory),
new Rec(RecognitionModel, RecLabel(), RecognitionVersion, onnxFactory));
new Rec(RecognitionModel, RecLabel(), RecognitionVersion, onnxFactory, allowDuplicateChar, threshold: threshold));
}
public static readonly PaddleOcrModelType V4 = Create(
@@ -239,9 +239,10 @@ public class PaddleOcrService : IOcrService, IDisposable
}
}
public PaddleOcrService(BgiOnnxFactory bgiOnnxFactory, PaddleOcrModelType modelType)
public PaddleOcrService(BgiOnnxFactory bgiOnnxFactory, PaddleOcrModelType modelType,
bool allowDuplicateChar = false, float threshold = 0.5f)
{
var (modelsDet, modelsRec) = modelType.Build(bgiOnnxFactory);
var (modelsDet, modelsRec) = modelType.Build(bgiOnnxFactory, allowDuplicateChar, threshold);
_localDetModel = modelsDet;
_localRecModel = modelsRec;
@@ -267,13 +268,8 @@ public class PaddleOcrService : IOcrService, IDisposable
/// </summary>
public OcrResult OcrResult(Mat mat)
{
if (mat.Channels() == 4)
{
using var mat3 = mat.CvtColor(ColorConversionCodes.BGRA2BGR);
return _OcrResult(mat3);
}
return _OcrResult(mat);
using var converted = ConvertBgrIfNeeded(mat);
return _OcrResult(converted ?? mat);
}
/// <summary>
@@ -338,6 +334,61 @@ public class PaddleOcrService : IOcrService, IDisposable
Math.Clamp(rect.Bottom, 0, size.Height));
}
/// <summary>
/// Converts the input to BGR if it is BGRA; otherwise returns null.
/// The caller must Dispose the returned Mat after use (if non-null).
/// </summary>
private static Mat? ConvertBgrIfNeeded(Mat mat)
{
return mat.Channels() == 4 ? mat.CvtColor(ColorConversionCodes.BGRA2BGR) : null;
}
/// <summary>
/// Locates text regions with the detector, runs DP fuzzy matching on each region, and returns the highest confidence (0 to 1).
/// </summary>
public double OcrMatch(Mat mat, string target)
{
var startTime = Stopwatch.GetTimestamp();
using var src = ConvertBgrIfNeeded(mat);
var bgr = src ?? mat;
var rects = _localDetModel.Run(bgr);
Mat[] mats = rects.Select(rect =>
{
var roi = bgr[GetCropedRect(rect.BoundingRect(), bgr.Size())];
return roi;
}).ToArray();
try
{
var score = _localRecModel.RunMatch(mats, target);
var time = Stopwatch.GetElapsedTime(startTime);
Debug.WriteLine($"PaddleOcrMatch 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}");
return score;
}
finally
{
foreach (var m in mats) m.Dispose();
}
}
/// <summary>
/// Runs DP fuzzy matching on the whole image directly, without the detector, and returns the confidence (0 to 1).
/// </summary>
public double OcrMatchDirect(Mat mat, string target)
{
var startTime = Stopwatch.GetTimestamp();
using var src = ConvertBgrIfNeeded(mat);
var bgr = src ?? mat;
var score = _localRecModel.RunMatch([bgr], target);
var time = Stopwatch.GetElapsedTime(startTime);
Debug.WriteLine($"PaddleOcrMatchDirect 耗时 {time.TotalMilliseconds}ms 目标: {target} 分数: {score:F4}");
return score;
}
public void Dispose()
{
_localDetModel.Dispose();

View File

@@ -7,22 +7,66 @@ using System.Text;
using BetterGenshinImpact.Core.Recognition.OCR.Engine;
using BetterGenshinImpact.Core.Recognition.OCR.Engine.data;
using BetterGenshinImpact.Core.Recognition.ONNX;
using BetterGenshinImpact.Helpers;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using OpenCvSharp;
namespace BetterGenshinImpact.Core.Recognition.OCR.Paddle;
public class Rec(
BgiOnnxModel model,
IReadOnlyList<string> labels,
OcrVersionConfig config,
BgiOnnxFactory bgiOnnxFactory)
: IDisposable
/// <summary>
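/// OCR recognizer supporting standard text recognition and dynamic-programming-based fuzzy matching.
/// Fuzzy matching performs subsequence matching of the target string against the model's raw output sequence and returns a confidence score from 0 to 1;
/// it tolerates OCR noise better than recognizing first and then matching strings.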
/// </summary>
public class Rec : IDisposable
{
private readonly InferenceSession _session = bgiOnnxFactory.CreateInferenceSession(model, true);
private readonly InferenceSession _session;
private readonly IReadOnlyList<string> _labels;
private readonly OcrVersionConfig _config;
private readonly bool _allowDuplicateChar;
private readonly float _threshold;
// _labels = File.ReadAllLines(labelFilePath);
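// Fields for fuzzy matching
/// <summary>Set of label lengths (descending), used for greedy longest-first matching of the target string</summary>
private readonly int[] _labelLengths;
/// <summary>Label string→index dictionary; indices start at 1, 0 is the CTC blank</summary>
private readonly IReadOnlyDictionary<string, int> _labelDict;
/// <summary>Weight array indexed by label index, used to weight inference scores; no weighting when null</summary>
private readonly float[]? _weights;
/// <summary>LRU cache from target string to label index sequence, speeding up repeated queries</summary>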
private readonly CacheHelper.LruCache<string, int[]> _targetCache = new(128);
/// <summary>
/// Named tensor structure for ONNX inference output, replacing the anonymous tuple (int[], float[]).
/// </summary>
private readonly record struct TensorResult(int Batch, int TimeSteps, int LabelCount, float[] Data);
public Rec(
BgiOnnxModel model,
IReadOnlyList<string> labels,
OcrVersionConfig config,
BgiOnnxFactory bgiOnnxFactory,
bool allowDuplicateChar = false,
Dictionary<string, float>? extraWeights = null,
float threshold = 0.5f)
{
_session = bgiOnnxFactory.CreateInferenceSession(model, true);
_labels = labels;
_config = config;
_allowDuplicateChar = allowDuplicateChar;
_threshold = threshold;
_labelDict = OcrUtils.CreateLabelDict(labels, out var labelLengths);
_labelLengths = labelLengths;
_weights = extraWeights is { Count: > 0 }
? OcrUtils.CreateWeights(extraWeights, _labelDict, labels.Count)
: null;
}
public void Dispose()
{
@@ -33,42 +77,14 @@ public class Rec(
GC.SuppressFinalize(this);
}
~Rec()
{
lock (_session)
{
_session.Dispose();
}
}
/// <summary>
/// Run OCR recognition on multiple images in batches.
/// Performs OCR recognition on multiple images, batch by batch.
/// </summary>
/// <param name="srcs">Array of images for OCR recognition.</param>
/// <param name="batchSize">Size of the batch to run OCR recognition on.</param>
/// <returns>Array of <see cref="OcrRecognizerResult" /> instances corresponding to OCR recognition results of the images.</returns>
public OcrRecognizerResult[] Run(Mat[] srcs, int batchSize = 0)
{
if (srcs.Length == 0) return [];
var chooseBatchSize = batchSize != 0 ? batchSize : Math.Min(8, Environment.ProcessorCount);
return srcs
.Select((x, i) => (mat: x, i))
.OrderBy(x => x.mat.Width)
.Chunk(chooseBatchSize)
.Select(x => (result: RunMulti(x.Select(x1 => x1.mat).ToArray()), ids: x.Select(x1 => x1.i).ToArray()))
.SelectMany(x => x.result.Zip(x.ids, (result, i) => (result, i)))
.OrderBy(x => x.i)
.Select(x => x.result)
.ToArray();
}
=> RunBatch(srcs, RunMulti, batchSize);
public OcrRecognizerResult Run(Mat src)
{
return RunMulti([src]).Single();
}
=> RunMulti([src]).Single();
private OcrRecognizerResult[] RunMulti(Mat[] srcs)
{
@@ -81,17 +97,173 @@ public class Rec(
throw new ArgumentException($"src[{i}] size should not be 0, wrong input picture provided?");
}
var modelHeight = config.Shape.Height;
var resultTensors = RunInference(srcs);
return resultTensors.SelectMany(tensor =>
{
GCHandle dataHandle = default;
try
{
dataHandle = GCHandle.Alloc(tensor.Data, GCHandleType.Pinned);
var dataPtr = dataHandle.AddrOfPinnedObject();
return Enumerable.Range(0, tensor.Batch)
.Select(i =>
{
StringBuilder sb = new();
var lastIndex = 0;
float score = 0;
var maxIdx = new int[2];
using var fullMat = Mat.FromPixelData(tensor.TimeSteps, tensor.LabelCount,
MatType.CV_32FC1,
dataPtr + i * tensor.TimeSteps * tensor.LabelCount * sizeof(float));
for (var n = 0; n < tensor.TimeSteps; ++n)
{
using var row = fullMat.Row(n);
row.MinMaxIdx(out _, out var maxVal, [], maxIdx);
if (maxIdx[1] > 0 && maxVal >= _threshold && (_allowDuplicateChar || !(n > 0 && maxIdx[1] == lastIndex)))
{
score += (float)maxVal;
sb.Append(OcrUtils.GetLabelByIndex(maxIdx[1], _labels));
}
lastIndex = maxIdx[1];
}
var text = sb.ToString();
return new OcrRecognizerResult(text, text.Length > 0 ? score / text.Length : 0);
})
.ToArray();
}
finally
{
if (dataHandle.IsAllocated) dataHandle.Free();
}
}).ToArray();
}
/// <summary>
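/// Converts the target string to a label index sequence, using an LRU cache to speed up repeated queries.
/// Characters that cannot be mapped to a label are skipped.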
/// </summary>
public int[] GetTarget(string target)
{
if (_targetCache.TryGet(target, out var cached) && cached is not null)
return cached;
var result = OcrUtils.MapStringToLabelIndices(target, _labelDict, _labelLengths);
_targetCache.Set(target, result);
return result;
}
/// <summary>
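/// Runs fuzzy matching on a batch of images and returns the maximum average confidence against the target string (0 to 1).
/// </summary>
/// <param name="srcs">Array of images to match</param>
/// <param name="target">Target string</param>
/// <param name="batchSize">Number of images per inference batch; 0 means automatic</param>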
public double RunMatch(Mat[] srcs, string target, int batchSize = 0)
{
if (srcs.Length == 0) return 0;
var targetIndexes = GetTarget(target);
if (targetIndexes.Length == 0) return 0;
var charLevelResults = RunBatch(srcs,
mats => ProcessForMatch(RunInference(mats), targetIndexes), batchSize);
return GetMaxScoreFlat(charLevelResults, targetIndexes);
}
/// <summary>
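/// Extracts the confidence of the target characters at each time step from the raw ONNX output tensor.
/// <para>
/// Unlike RunMulti (standard OCR), this method does not take the argmax (MinMaxIdx);
/// it looks up the raw confidence directly at the label index of each target character.
/// The DP can therefore still use a target character's actual score even when it is not the top candidate at a given time step.
/// </para>
/// </summary>
/// <param name="resultTensors">Array of (shape, data) tensors returned by RunInference</param>
/// <param name="targetIndexes">Label index sequence mapped from the target string</param>
/// <returns>One (labelIndex, confidence) array per image, for use by DP matching</returns>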
private (int, float)[][] ProcessForMatch(TensorResult[] resultTensors, int[] targetIndexes)
{
// Deduplicate the target characters (excluding the CTC blank, index=0)
var targetSet = new HashSet<int>(targetIndexes);
targetSet.Remove(0);
return resultTensors.Select(tensor =>
{
var chars = new List<(int, float)>();
for (var n = 0; n < tensor.TimeSteps * tensor.Batch; n++)
{
// Look up the target character's confidence directly by index instead of taking the argmax over the whole row
var rowOffset = n * tensor.LabelCount;
foreach (var labelIdx in targetSet)
{
if (labelIdx >= tensor.LabelCount) continue;
var raw = tensor.Data[rowOffset + labelIdx];
var confidence = _weights is not null
? raw * _weights[labelIdx]
: raw;
if (confidence > _threshold)
chars.Add((labelIdx, confidence));
}
}
return chars.ToArray();
}).ToArray();
}
/// <summary>
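/// Flattens the character-level results of multiple images, then computes the maximum match score against target.
/// The denominator is target.Length, so the result is the average confidence per target character (0 to 1).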
/// </summary>
private static double GetMaxScoreFlat((int, float)[][] result, int[] target)
{
var flatResult = result.SelectMany(x => x).ToArray();
return OcrUtils.GetMaxScoreDp(flatResult, target, target.Length);
}
/// <summary>
/// Generic batching: sort by width, run inference in batches, restore the original order
/// </summary>
private T[] RunBatch<T>(Mat[] srcs, Func<Mat[], T[]> process, int batchSize = 0)
{
if (srcs.Length == 0) return [];
var chooseBatchSize = batchSize != 0 ? batchSize : Math.Min(8, Environment.ProcessorCount);
return srcs
.Select((x, i) => (mat: x, i))
.OrderBy(x => x.mat.Width)
.Chunk(chooseBatchSize)
.Select(chunk =>
{
var mats = chunk.Select(x => x.mat).ToArray();
var result = process(mats);
return (result, ids: chunk.Select(x => x.i).ToArray());
})
.SelectMany(x => x.result.Zip(x.ids, (r, i) => (r, i)))
.OrderBy(x => x.i)
.Select(x => x.r)
.ToArray();
}
/// <summary>
/// Runs ONNX inference and returns the raw (shape, data) tensor for each image
/// </summary>
private TensorResult[] RunInference(Mat[] srcs)
{
var modelHeight = _config.Shape.Height;
var maxWidth = (int)Math.Ceiling(srcs.Max(src =>
{
var size = src.Size();
return 1.0 * size.Width / size.Height * modelHeight;
}));
List<IMemoryOwner<float>> owners = [];
(int[], float[])[] resultTensors;
try
{
resultTensors = srcs
return srcs
// .AsParallel()
.Select(src =>
{
@@ -111,12 +283,10 @@ public class Rec(
{
owners.Add(owner);
}
return result;
}
finally
{
// Only dispose Mats created in this scope
if (channel3 != null && !ReferenceEquals(channel3, src))
{
channel3.Dispose();
@@ -124,75 +294,31 @@ public class Rec(
}
})
.Select(inputTensor =>
{
lock (_session)
{
lock (_session)
{
// Multithreaded inference causes problems; taking a lock resolves them.
using IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = _session.Run([
NamedOnnxValue.CreateFromTensor(_session.InputNames[0], inputTensor)
]);
var output = results[0];
if (output.ElementType is not TensorElementType.Float)
throw new Exception($"Unexpected output tensor type: {output.ElementType}");
// Multithreaded inference causes problems; taking a lock resolves them.
using IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = _session.Run([
NamedOnnxValue.CreateFromTensor(_session.InputNames[0], inputTensor)
]);
var output = results[0];
if (output.ElementType is not TensorElementType.Float)
throw new Exception($"Unexpected output tensor type: {output.ElementType}");
if (output.ValueType is not OnnxValueType.ONNX_TYPE_TENSOR)
throw new Exception($"Unexpected output tensor value type: {output.ValueType}");
var tensor = output.AsTensor<float>();
// Due to a known bug, tensor memory is freed after use under DML, so code outside the lock would fail
return (tensor.Dimensions.ToArray(), tensor.ToArray());
}
if (output.ValueType is not OnnxValueType.ONNX_TYPE_TENSOR)
throw new Exception($"Unexpected output tensor value type: {output.ValueType}");
var tensor = output.AsTensor<float>();
// Due to a known bug, tensor memory is freed after use under DML, so code outside the lock would fail
var dims = tensor.Dimensions;
return new TensorResult(dims[0], dims[1], dims[2], tensor.ToArray());
}
).ToArray();
}).ToArray();
}
finally
{
owners.ForEach(x => { x.Dispose(); });
}
return resultTensors.SelectMany(resultTensor =>
{
var resultArray = resultTensor.Item2;
var resultShape = resultTensor.Item1;
GCHandle dataHandle = default;
try
{
dataHandle = GCHandle.Alloc(resultArray, GCHandleType.Pinned);
var dataPtr = dataHandle.AddrOfPinnedObject();
var labelCount = resultShape[2];
var charCount = resultShape[1];
return Enumerable.Range(0, resultShape[0])
.Select(i =>
{
StringBuilder sb = new();
var lastIndex = 0;
float score = 0;
for (var n = 0; n < charCount; ++n)
{
using var mat = Mat.FromPixelData(1, labelCount, MatType.CV_32FC1,
dataPtr + (n + i * charCount) * labelCount * sizeof(float));
var maxIdx = new int[2];
mat.MinMaxIdx(out _, out var maxVal, [], maxIdx);
if (maxIdx[1] > 0 && !(n > 0 && maxIdx[1] == lastIndex))
{
score += (float)maxVal;
sb.Append(OcrUtils.GetLabelByIndex(maxIdx[1], labels));
}
lastIndex = maxIdx[1];
}
return new OcrRecognizerResult(sb.ToString(), score / sb.Length);
})
.ToArray();
}
finally
{
dataHandle.Free();
}
}).ToArray();
}
public string GetConfigName => config.Name;
}
public string GetConfigName => _config.Name;
}
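For readers following the decode loop removed in the hunk above: it takes the arg-max label of each time step, skips the CTC blank (index 0), folds consecutive repeats, and averages the kept confidences. The sketch below restates that logic on plain arrays and shows the two behaviours the commit message mentions, the `AllowDuplicateChar` switch and the guard against dividing by zero on empty text. `CtcGreedyDecoder`, `Decode`, the jagged-array input and the index-to-label mapping (1..N are labels, N+1 is a space) are assumptions made for illustration, not the PR's actual `Rec` code. It also averages over the number of kept frames rather than `sb.Length`.

```csharp
using System;
using System.Collections.Generic;
using System.Text;

// Hypothetical helper, written only to illustrate greedy CTC decoding.
public static class CtcGreedyDecoder
{
    // frames[n] holds the per-label probabilities of time step n; index 0 is the CTC blank.
    public static (string Text, float Score) Decode(
        float[][] frames, IReadOnlyList<string> labels, bool allowDuplicateChar = false)
    {
        var sb = new StringBuilder();
        var lastIndex = 0;
        var score = 0f;
        var kept = 0;

        for (var n = 0; n < frames.Length; n++)
        {
            // Arg-max over the label axis.
            var maxIdx = 0;
            for (var k = 1; k < frames[n].Length; k++)
            {
                if (frames[n][k] > frames[n][maxIdx]) maxIdx = k;
            }

            var isRepeat = n > 0 && maxIdx == lastIndex;
            if (maxIdx > 0 && (allowDuplicateChar || !isRepeat))
            {
                score += frames[n][maxIdx];
                kept++;
                // Assumed mapping: 1..labels.Count are labels, labels.Count + 1 is the space.
                sb.Append(maxIdx <= labels.Count ? labels[maxIdx - 1] : " ");
            }
            lastIndex = maxIdx;
        }

        // Guard the average: with no kept frames, "score / count" would be 0/0 = NaN.
        return (sb.ToString(), kept == 0 ? 0f : score / kept);
    }
}
```

With repeat folding enabled, two adjacent frames that both pick the same label yield a single character. With `allowDuplicateChar` set to true they yield two.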

View File

@@ -1,4 +1,5 @@
using BetterGenshinImpact.Core.Recognition;
using BetterGenshinImpact.Core.Recognition.OCR;
using BetterGenshinImpact.Core.Simulator;
using BetterGenshinImpact.Core.Simulator.Extensions;
using BetterGenshinImpact.GameTask.Common.BgiVision;
@@ -9,6 +10,7 @@ using BetterGenshinImpact.View.Drawable;
using Microsoft.Extensions.Logging;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading;
@@ -28,56 +30,17 @@ public class SwitchPartyTask
public async Task<bool> Start(string partyName, CancellationToken ct)
{
bool isInPartyViewUi = false;
var useOcrMatch = TaskContext.Instance().Config.OtherConfig.OcrConfig.UseOcrMatchForPartySwitch;
Logger.LogInformation("尝试切换至队伍: {Name}", partyName);
using var ra1 = CaptureToRectArea();
// Make sure the party setup screen is open
bool isInPartyViewUi = false;
if (!Bv.IsInPartyViewUi(ra1))
{
isInPartyViewUi = true;
// If we are not on the main UI, return to it first
if (!Bv.IsInMainUi(ra1))
{
await _returnMainUiTask.Start(ct);
await Delay(200, ct);
using var raAfterMain = CaptureToRectArea();
if (!Bv.IsInMainUi(raAfterMain))
{
throw new InvalidOperationException("未能返回主界面");
}
}
// Try to open the party setup screen
const int maxAttempts = 2;
bool isOpened = false;
for (int attempt = 1; attempt <= maxAttempts; attempt++)
{
Simulation.SendInput.SimulateAction(GIActions.OpenPartySetupScreen);
// Allow for about 2 s of loading time: poll for 4.2 s in total, and throw if the screen still has not opened
for (int i = 0; i < 7; i++) // 7 checks
{
await Delay(600, ct);
using var raCheck = CaptureToRectArea();
if (Bv.IsInPartyViewUi(raCheck))
{
isOpened = true;
break;
}
}
if (isOpened)
{
break; // the screen is open, stop retrying
}
}
if (!isOpened)
{
throw new PartySetupFailedException("未能打开队伍配置界面");
}
await EnsurePartyViewOpen(ra1, ct);
}
await Delay(500, ct);
@@ -85,33 +48,15 @@ public class SwitchPartyTask
using var ra = CaptureToRectArea();
var partyViewBtn = ra.Find(ElementAssets.Instance.PartyBtnChooseView);
// OCR the current party name (single-character names are not supported, and the name must not contain spaces)
var currTeamName = ra.Find(new RecognitionObject
if (!partyViewBtn.IsExist())
{
RecognitionType = RecognitionTypes.Ocr,
RegionOfInterest = new Rect(partyViewBtn.Right, partyViewBtn.Top, (int)(350 * _assetScale),
partyViewBtn.Height)
}).Text;
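var tempName = currTeamName
.Replace("\"", "") // strip all double quotes (fixes the stray "" seen in the logs)
.Replace("\r\n", "") // strip Windows line endings
.Replace("\r", ""); // strip any remaining carriage returns
// Core logic: find the first newline (\n), then cut it off together with everything after it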
int firstNewLineIndex = tempName.IndexOf('\n');
if (firstNewLineIndex != -1) // a newline exists, keep only the text before it
{
tempName = tempName.Substring(0, firstNewLineIndex);
Logger.LogWarning("未找到队伍选择按钮,无法判断当前队伍");
throw new PartySetupFailedException("未找到队伍选择按钮");
}
// Finally, trim all leading and trailing whitespace (spaces, tabs, \r, etc.) to get a clean party name
currTeamName = tempName.Trim();
Logger.LogInformation("切换队伍,当前队伍名称: {Text},使用正则表达式规则进行模糊匹配", currTeamName);
if (Regex.IsMatch(currTeamName, partyName))
// Check whether the current party is already the target
if (IsCurrentTeamMatch(ra, partyViewBtn, partyName, useOcrMatch))
{
Logger.LogInformation("当前队伍[{Name}]即为目标队伍,无需切换", currTeamName);
if (isInPartyViewUi)
{
Simulation.SendInput.Keyboard.KeyPress(User32.VK.VK_ESCAPE);
@@ -122,101 +67,76 @@ public class SwitchPartyTask
return true;
}
var menu = await NewRetry.WaitForElementAppear(
ElementAssets.Instance.PartyBtnDelete,
() => partyViewBtn.Click(), // click the party selection button
ct,
4,
500
);
if (!menu)
{
throw new PartySetupFailedException("未能打开队伍选择页面");
}
// Open the party selection page
var partyDeleteBtn = await OpenPartyChoosePage(partyViewBtn, ct);
await ScrollToTop(ct);
ImageRegion? switchRa = null;
Region? partyDeleteBtn = null;
using (var ocrRa = CaptureToRectArea())
{
var openPartyChooseSuccess = await NewRetry.WaitForAction(() =>
{
switchRa = ocrRa;
partyDeleteBtn = switchRa.Find(ElementAssets.Instance.PartyBtnDelete);
return partyDeleteBtn.IsExist();
}, ct, 5);
if (!openPartyChooseSuccess || switchRa == null || partyDeleteBtn == null)
{
throw new PartySetupFailedException("未能打开队伍配置界面");
}
}
// Scroll the list to the very top
await Task.Delay(50, ct);
GameCaptureRegion.GameRegion1080PPosClick(700, 125);
await Task.Delay(50, ct);
Simulation.SendInput.Mouse.LeftButtonDown();
await Task.Delay(450, ct);
Simulation.SendInput.Mouse.LeftButtonUp();
await Task.Delay(100, ct);
Rect regionOfInterest = new Rect(0, (int)(80 * _assetScale), partyDeleteBtn.Right, partyDeleteBtn.Top - (int)(80 * _assetScale));
RecognitionObject recognitionObject = new RecognitionObject
// Search for the target party page by page
Rect regionOfInterest = new(0, (int)(80 * _assetScale), partyDeleteBtn.Right, partyDeleteBtn.Top - (int)(80 * _assetScale));
var recognitionObject = new RecognitionObject
{
RecognitionType = RecognitionTypes.Ocr,
RegionOfInterest = regionOfInterest,
DrawOnWindow = true,
Name = "队伍名称",
DrawOnWindowPen= System.Drawing.Pens.White
DrawOnWindowPen = System.Drawing.Pens.White
};
// Search page by page
try
{
for (var i = 0; i < 16; i++) // version 6.0 allows at most 20 parties
for (var i = 0; i < 16; i++) // version 6.0 allows at most 20 parties
{
using var page = CaptureToRectArea();
var nameList = page.FindMulti(recognitionObject);
var partySwitchNameRaList = page.FindMulti(recognitionObject);
if (partySwitchNameRaList == null || partySwitchNameRaList.Count <= 0)
if (nameList == null || nameList.Count <= 0)
{
Logger.LogInformation("管理队伍界面文字识别失败");
break;
}
// If the target is on the current page, click it directly
foreach (var textRegion in partySwitchNameRaList)
// Look for a match on the current page
var (match, score) = FindMatchInPage(page, nameList, partyName, useOcrMatch);
if (match != null)
{
if (Regex.IsMatch(textRegion.Text, partyName))
{
page.ClickTo(textRegion.Right + textRegion.Width, textRegion.Bottom);
await Delay(200, ct);
Logger.LogInformation("切换队伍成功: {Text}", textRegion.Text);
await ConfirmParty(page, ct, isInPartyViewUi);
RunnerContext.Instance.ClearCombatScenes();
return true;
}
page.ClickTo(match.Right + match.Width, match.Bottom);
await Delay(200, ct);
if (useOcrMatch)
Logger.LogInformation("切换队伍成功: {Text}(匹配分数: {Score:F4})", match.Text, score);
else
Logger.LogInformation("切换队伍成功: {Text}", match.Text);
await ConfirmParty(page, ct, isInPartyViewUi);
RunnerContext.Instance.ClearCombatScenes();
return true;
}
Region lowest = partySwitchNameRaList.Where(r => r.X > 35 * _assetScale && r.X < 100 * _assetScale).OrderBy(r => r.Y).Last();
// Check whether every party has been visited
var lowest = nameList
.Where(r => r.X > 35 * _assetScale && r.X < 100 * _assetScale)
.OrderBy(r => r.Y)
.LastOrDefault();
if (lowest == null)
{
Logger.LogInformation("未找到符合坐标范围的队伍名称,跳过翻页判断");
continue;
}
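lowest.DrawSelf("底部的队伍");
if (lowest.Y < 777 * _assetScale) // an empty bottom slot has no party name, which tells us the list has been fully traversed
if (lowest.Y < 777 * _assetScale) // an empty bottom slot has no party name, which tells us the list has been fully traversed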
{
Logger.LogInformation("已抵达最后一个队伍");
break;
}
// Click through to the next page
// Next page
if (i == 0)
{
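// #ebe4d8 On the first pass, click the first entry once, in case the fifth entry is already in a clicked state
// On the first pass, click the first entry once, in case the fifth entry is already in a clicked state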
page.ClickTo(600 * _assetScale, 200 * _assetScale);
await Task.Delay(300, ct); // wait for the animation
await Task.Delay(300, ct);
}
page.ClickTo(regionOfInterest.X + regionOfInterest.Width / 2, lowest.Bottom); // click the bottom-most party to scroll the list down
page.ClickTo(regionOfInterest.X + regionOfInterest.Width / 2, lowest.Bottom);
await Delay(400, ct);
}
}
@@ -227,14 +147,181 @@ public class SwitchPartyTask
// Not found
Logger.LogError("未找到队伍: {Name},返回主界面", partyName);
Logger.LogInformation("如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试正则表达式");
Logger.LogInformation(useOcrMatch
? "如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试调整 OcrMatch 模糊匹配阈值"
: "如果找不到设定的队伍名,有可能是文字识别效果不佳,请尝试正则表达式");
await _returnMainUiTask.Start(ct);
return false;
}
/// <summary>
/// Ensures the party setup screen is open. If we are not on the main UI, returns to it first and then opens party setup.
/// </summary>
private async Task EnsurePartyViewOpen(ImageRegion currentScreen, CancellationToken ct)
{
if (!Bv.IsInMainUi(currentScreen))
{
await _returnMainUiTask.Start(ct);
await Delay(200, ct);
using var raMain = CaptureToRectArea();
if (!Bv.IsInMainUi(raMain))
throw new InvalidOperationException("未能返回主界面");
}
const int maxAttempts = 2;
for (int attempt = 1; attempt <= maxAttempts; attempt++)
{
Simulation.SendInput.SimulateAction(GIActions.OpenPartySetupScreen);
for (int i = 0; i < 7; i++) // allow about 2 s of loading time; 7 checks at 600 ms give 4.2 s in total
{
await Delay(600, ct);
using var raCheck = CaptureToRectArea();
if (Bv.IsInPartyViewUi(raCheck)) return;
}
}
throw new PartySetupFailedException("未能打开队伍配置界面");
}
/// <summary>
/// Checks whether the current party name matches the target.
/// </summary>
private bool IsCurrentTeamMatch(ImageRegion ra, Region partyViewBtn, string partyName, bool useOcrMatch)
{
var roi = new Rect(partyViewBtn.Right, partyViewBtn.Top, (int)(350 * _assetScale), partyViewBtn.Height);
if (useOcrMatch)
{
var matchService = OcrFactory.PaddleMatch;
var threshold = TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold;
using var region = ra.DeriveCrop(roi);
var score = matchService.OcrMatch(region.SrcMat, partyName);
Logger.LogInformation("切换队伍,当前队伍 OcrMatch 分数: {Score:F4},判断阈值: {Threshold}", score, threshold);
if (score >= threshold)
{
Logger.LogInformation("当前队伍即为目标队伍(匹配分数: {Score:F4}),无需切换", score);
return true;
}
return false;
}
var text = CleanOcrText(ra.Find(new RecognitionObject
{
RecognitionType = RecognitionTypes.Ocr,
RegionOfInterest = roi
}).Text);
Logger.LogInformation("切换队伍,当前队伍名称: {Text},使用正则表达式规则进行模糊匹配", text);
if (Regex.IsMatch(text, partyName))
{
Logger.LogInformation("当前队伍[{Name}]即为目标队伍,无需切换", text);
return true;
}
return false;
}
/// <summary>
/// Searches the text regions of the current page for the party that matches the target.
/// </summary>
private (Region? match, double score) FindMatchInPage(
ImageRegion page, List<Region> textRegions, string partyName, bool useOcrMatch)
{
if (useOcrMatch)
{
var matchService = OcrFactory.PaddleMatch;
var threshold = TaskContext.Instance().Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold;
Region? bestMatch = null;
double bestScore = 0;
var imgW = page.SrcMat.Width;
var imgH = page.SrcMat.Height;
foreach (var region in textRegions)
{
var cx = Math.Max(0, region.X);
var cy = Math.Max(0, region.Y);
var cw = Math.Min(region.Width, imgW - cx);
var ch = Math.Min(region.Height, imgH - cy);
if (cw <= 0 || ch <= 0)
continue;
using var cropped = page.DeriveCrop(cx, cy, cw, ch);
var score = matchService.OcrMatchDirect(cropped.SrcMat, partyName);
if (score >= threshold && score > bestScore)
{
bestScore = score;
bestMatch = region;
}
}
return (bestMatch, bestScore);
}
foreach (var region in textRegions)
{
if (Regex.IsMatch(region.Text, partyName))
return (region, 0);
}
return (null, 0);
}
/// <summary>
/// Opens the party selection page (clicks the selection button and waits for it to load).
/// </summary>
private static async Task<Region> OpenPartyChoosePage(Region partyViewBtn, CancellationToken ct)
{
var menu = await NewRetry.WaitForElementAppear(
ElementAssets.Instance.PartyBtnDelete,
() => partyViewBtn.Click(),
ct, 4, 500);
if (!menu)
throw new PartySetupFailedException("未能打开队伍选择页面");
Region? partyDeleteBtn = null;
var success = await NewRetry.WaitForAction(() =>
{
using var ocrRa = CaptureToRectArea();
partyDeleteBtn = ocrRa.Find(ElementAssets.Instance.PartyBtnDelete);
return partyDeleteBtn.IsExist();
}, ct, 5);
if (!success || partyDeleteBtn == null)
throw new PartySetupFailedException("未能打开队伍配置界面");
return partyDeleteBtn;
}
/// <summary>
/// Scrolls the list to the very top.
/// </summary>
private static async Task ScrollToTop(CancellationToken ct)
{
await Task.Delay(50, ct);
GameCaptureRegion.GameRegion1080PPosClick(700, 125);
await Task.Delay(50, ct);
Simulation.SendInput.Mouse.LeftButtonDown();
await Task.Delay(450, ct);
Simulation.SendInput.Mouse.LeftButtonUp();
await Task.Delay(100, ct);
}
/// <summary>
/// Strips noise characters from an OCR result.
/// </summary>
private static string CleanOcrText(string? text)
{
if (string.IsNullOrEmpty(text))
return string.Empty;
var cleaned = text.Replace("\"", "").Replace("\r\n", "").Replace("\r", "");
var newLineIndex = cleaned.IndexOf('\n');
if (newLineIndex != -1)
cleaned = cleaned[..newLineIndex];
return cleaned.Trim();
}
private async Task ConfirmParty(ImageRegion page, CancellationToken ct, bool isInPartyViewUi = false)
{
var r1 = Bv.ClickWhiteConfirmButton(page.DeriveCrop(0, page.Height / 4, page.Width / 4, page.Height - page.Height / 4));
Bv.ClickWhiteConfirmButton(page.DeriveCrop(0, page.Height / 4, page.Width / 4, page.Height - page.Height / 4));
var partyChooseUiClosed = await NewRetry.WaitForAction(() =>
{
using var ra2 = CaptureToRectArea();
@@ -244,9 +331,10 @@ public class SwitchPartyTask
{
throw new PartySetupFailedException("选择队伍失败,等待队伍切换超时!");
}
await Delay(200, ct);
using var ra = CaptureToRectArea();
var r2 = Bv.ClickWhiteConfirmButton(ra.DeriveCrop(page.Width - page.Width / 4, page.Height / 4, page.Width / 4, page.Height - page.Height / 4));
Bv.ClickWhiteConfirmButton(ra.DeriveCrop(page.Width - page.Width / 4, page.Height / 4, page.Width / 4, page.Height - page.Height / 4));
await Delay(500, ct);
if (isInPartyViewUi) await _returnMainUiTask.Start(ct);
}
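The OcrMatch branch of `FindMatchInPage` above does two small things that are easy to get subtly wrong: it clamps each text region to the captured image before cropping, and it keeps the highest score that reaches the threshold. A minimal sketch of both steps on plain values follows. `PartyMatchSketch`, `Clamp` and `PickBest` are hypothetical names, and the tuple and list types stand in for `Region` and the OCR service, which are not reproduced here.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, written only to illustrate the selection logic.
public static class PartyMatchSketch
{
    // Clamp a rectangle to an image of size imgW x imgH.
    // Returns null when nothing of the rectangle is left inside the image.
    public static (int X, int Y, int W, int H)? Clamp(int x, int y, int w, int h, int imgW, int imgH)
    {
        var cx = Math.Max(0, x);
        var cy = Math.Max(0, y);
        var cw = Math.Min(w, imgW - cx);
        var ch = Math.Min(h, imgH - cy);
        if (cw <= 0 || ch <= 0) return null;
        return (cx, cy, cw, ch);
    }

    // Return the index of the highest score that reaches the threshold, or -1 if none does.
    public static int PickBest(IReadOnlyList<double> scores, double threshold)
    {
        var best = -1;
        var bestScore = 0.0;
        for (var i = 0; i < scores.Count; i++)
        {
            if (scores[i] >= threshold && scores[i] > bestScore)
            {
                bestScore = scores[i];
                best = i;
            }
        }
        return best;
    }
}
```

Because the comparison is strict (`>`), the first region wins when two regions tie, and a score of exactly 0 never matches even if the threshold is 0.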

View File

@@ -0,0 +1,103 @@
using System;
using System.Collections.Generic;
namespace BetterGenshinImpact.Helpers;
public abstract class CacheHelper
{
public class LruCache<TKey, TValue> where TKey : notnull where TValue : class
{
private readonly int _capacity;
private readonly TimeSpan? _expireAfter;
private readonly Dictionary<TKey, LinkedListNode<(TKey Key, TValue Value, DateTime ExpireAt)>> _cacheMap;
private readonly LinkedList<(TKey Key, TValue Value, DateTime ExpireAt)> _lruList;
private readonly object _lock = new();
public LruCache(int capacity, TimeSpan? expireAfter = null)
{
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(capacity);
_capacity = capacity;
_expireAfter = expireAfter;
_cacheMap = new Dictionary<TKey, LinkedListNode<(TKey, TValue, DateTime)>>();
_lruList = [];
}
public bool TryGet(TKey key, out TValue? value)
{
lock (_lock)
{
if (_cacheMap.TryGetValue(key, out var node))
{
if (_expireAfter.HasValue && DateTime.UtcNow > node.Value.ExpireAt)
{
_lruList.Remove(node);
_cacheMap.Remove(key);
value = null;
return false;
}
_lruList.Remove(node);
_lruList.AddFirst(node);
value = node.Value.Value;
return true;
}
value = null;
return false;
}
}
public void Set(TKey key, TValue value)
{
lock (_lock)
{
var expireAt = _expireAfter.HasValue ? DateTime.UtcNow.Add(_expireAfter.Value) : default;
if (_cacheMap.TryGetValue(key, out var node))
{
node.Value = (key, value, expireAt);
_lruList.Remove(node);
_lruList.AddFirst(node);
}
else
{
if (_cacheMap.Count >= _capacity)
{
var lru = _lruList.Last;
if (lru != null)
{
_cacheMap.Remove(lru.Value.Key);
_lruList.RemoveLast();
}
}
var newNode = new LinkedListNode<(TKey, TValue, DateTime)>((key, value, expireAt));
_lruList.AddFirst(newNode);
_cacheMap[key] = newNode;
}
}
}
public bool Remove(TKey key)
{
lock (_lock)
{
if (!_cacheMap.TryGetValue(key, out var node)) return false;
_lruList.Remove(node);
_cacheMap.Remove(key);
return true;
}
}
public int Count
{
get { lock (_lock) { return _cacheMap.Count; } }
}
public void Clear()
{
lock (_lock)
{
_cacheMap.Clear();
_lruList.Clear();
}
}
}
}
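The commit message notes that the TTL test margin had to be widened to keep CI stable. One way to make expiry deterministic in tests is to inject the clock. The sketch below is a hypothetical variant (`ClockedLruCache` and its `now` parameter are not part of this PR) that keeps the same eviction and lazy-expiry rules as the `LruCache` above but reads the time from a delegate.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical variant, not the PR's class: same LRU and expiry rules, but the clock is injected.
public sealed class ClockedLruCache<TKey, TValue> where TKey : notnull where TValue : class
{
    private readonly int _capacity;
    private readonly TimeSpan? _expireAfter;
    private readonly Func<DateTime> _now;
    private readonly Dictionary<TKey, LinkedListNode<(TKey Key, TValue Value, DateTime ExpireAt)>> _map = new();
    private readonly LinkedList<(TKey Key, TValue Value, DateTime ExpireAt)> _list = new();
    private readonly object _lock = new();

    public ClockedLruCache(int capacity, TimeSpan? expireAfter = null, Func<DateTime>? now = null)
    {
        if (capacity <= 0) throw new ArgumentOutOfRangeException(nameof(capacity));
        _capacity = capacity;
        _expireAfter = expireAfter;
        _now = now ?? (() => DateTime.UtcNow);
    }

    public bool TryGet(TKey key, out TValue? value)
    {
        lock (_lock)
        {
            value = null;
            if (!_map.TryGetValue(key, out var node)) return false;
            if (_expireAfter.HasValue && _now() > node.Value.ExpireAt)
            {
                // Expired entries are dropped lazily, on access.
                _list.Remove(node);
                _map.Remove(key);
                return false;
            }
            // A hit moves the entry to the head (most recently used).
            _list.Remove(node);
            _list.AddFirst(node);
            value = node.Value.Value;
            return true;
        }
    }

    public void Set(TKey key, TValue value)
    {
        lock (_lock)
        {
            var expireAt = _expireAfter.HasValue ? _now().Add(_expireAfter.Value) : default;
            if (_map.TryGetValue(key, out var node))
            {
                // Overwrite in place and mark as most recently used.
                node.Value = (key, value, expireAt);
                _list.Remove(node);
                _list.AddFirst(node);
                return;
            }
            if (_map.Count >= _capacity && _list.Last is { } lru)
            {
                // Evict the tail (least recently used).
                _map.Remove(lru.Value.Key);
                _list.RemoveLast();
            }
            _map[key] = _list.AddFirst((key, value, expireAt));
        }
    }
}
```

A test can then advance a captured `DateTime` variable instead of calling `Thread.Sleep`, so the expiry check no longer depends on scheduler timing.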

View File

@@ -21,8 +21,8 @@
<ui:TextBlock Margin="0,0,0,8"
FontTypography="BodyStrong"
Text="软件设置" />
<ui:CardControl Margin="0,0,0,12" Icon="{ui:SymbolIcon Globe24}">
<ui:CardControl Margin="0,0,0,12" Icon="{ui:SymbolIcon Globe24}">
<ui:CardControl.Header>
<Grid>
<Grid.RowDefinitions>
@@ -434,19 +434,19 @@
Width="200"
Margin="0,0,12,0"
VerticalAlignment="Center"
Minimum="0"
Maximum="1"
TickFrequency="0.1"
IsSnapToTickEnabled="True"
Maximum="1"
Minimum="0"
TickFrequency="0.1"
Value="{Binding Config.MaskWindowConfig.TextOpacity, Mode=TwoWay}" />
<TextBlock Grid.Row="0"
Grid.Column="2"
VerticalAlignment="Center"
HorizontalAlignment="Left"
MinWidth="40"
Margin="0,0,36,0"
HorizontalAlignment="Left"
VerticalAlignment="Center"
FontFamily="Cascadia Mono, Consolas, Courier New, monospace"
FontSize="14"
MinWidth="40"
Text="{Binding Config.MaskWindowConfig.TextOpacity, StringFormat=F2}" />
</Grid>
@@ -1004,7 +1004,7 @@
</Grid>
</StackPanel>
</ui:CardExpander>
<!-- Manual import of a local script repository -->
<!--<ui:CardControl Margin="0,0,0,12" Icon="{ui:SymbolIcon Folder24}">
<ui:CardControl.Header>
@@ -1034,14 +1034,13 @@
Command="{Binding ImportLocalScriptsRepoZipCommand}"
Content="导入" />
</Grid>
</ui:CardControl.Header>
</ui:CardControl>-->
<!-- Other settings -->
<ui:CardExpander Margin="0,0,0,12"
ContentPadding="0">
<ui:CardExpander Margin="0,0,0,12" ContentPadding="0">
<ui:CardExpander.Icon>
<ui:FontIcon Glyph="&#xf141;" Style="{StaticResource FaFontIconStyle}" />
</ui:CardExpander.Icon>
@@ -1179,10 +1178,10 @@
Grid.Column="1"
MinWidth="100"
Margin="0,0,36,0"
DisplayMemberPath="Item2"
ItemsSource="{Binding ServerTimeZones}"
SelectedValue="{Binding Config.OtherConfig.ServerTimeZoneOffset}"
SelectedValuePath="Item1"
DisplayMemberPath="Item2" />
SelectedValuePath="Item1" />
</Grid>
<ui:CardExpander Margin="0,0,0,12"
ContentPadding="0"
@@ -1715,7 +1714,141 @@
</b:Interaction.Triggers>
</ComboBox>
</Grid>
<Grid Margin="16">
<Grid.RowDefinitions>
<RowDefinition Height="Auto" />
<RowDefinition Height="Auto" />
</Grid.RowDefinitions>
<Grid.ColumnDefinitions>
<ColumnDefinition Width="*" />
<ColumnDefinition Width="Auto" />
</Grid.ColumnDefinitions>
<ui:TextBlock Grid.Row="0"
Grid.Column="0"
FontTypography="Body"
Text="允许连续重复字符"
TextWrapping="Wrap" />
<ui:TextBlock Grid.Row="1"
Grid.Column="0"
Foreground="{ui:ThemeResource TextFillColorTertiaryBrush}"
Text="关闭CTC重复字符折叠,允许OCR结果中出现连续重复字符。修改后需重新加载OCR引擎才能生效。"
TextWrapping="Wrap" />
<ui:ToggleSwitch Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="1"
Margin="0,0,36,0"
IsChecked="{Binding Config.OtherConfig.OcrConfig.AllowDuplicateChar, Mode=TwoWay}" />
</Grid>
<Grid Margin="16">
<Grid.RowDefinitions>
<RowDefinition Height="Auto" />
<RowDefinition Height="Auto" />
</Grid.RowDefinitions>
<Grid.ColumnDefinitions>
<ColumnDefinition Width="*" />
<ColumnDefinition Width="Auto" />
</Grid.ColumnDefinitions>
<ui:TextBlock Grid.Row="0"
Grid.Column="0"
FontTypography="Body"
Text="切换队伍使用OCR模糊匹配"
TextWrapping="Wrap" />
<ui:TextBlock Grid.Row="1"
Grid.Column="0"
Foreground="{ui:ThemeResource TextFillColorTertiaryBrush}"
Text="开启后使用OcrMatch模糊匹配识别队伍名称,关闭则使用正则表达式匹配。"
TextWrapping="Wrap" />
<ui:ToggleSwitch Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="1"
Margin="0,0,36,0"
IsChecked="{Binding Config.OtherConfig.OcrConfig.UseOcrMatchForPartySwitch, Mode=TwoWay}" />
</Grid>
<Grid Margin="16">
<Grid.RowDefinitions>
<RowDefinition Height="Auto" />
<RowDefinition Height="Auto" />
</Grid.RowDefinitions>
<Grid.ColumnDefinitions>
<ColumnDefinition Width="*" />
<ColumnDefinition Width="Auto" />
<ColumnDefinition Width="Auto" />
</Grid.ColumnDefinitions>
<ui:TextBlock Grid.Row="0"
Grid.Column="0"
FontTypography="Body"
Text="OCR模糊匹配阈值"
TextWrapping="Wrap" />
<ui:TextBlock Grid.Row="1"
Grid.Column="0"
Foreground="{ui:ThemeResource TextFillColorTertiaryBrush}"
Text="【0~1】分数 ≥ 阈值视为匹配成功,默认值为 0.8。"
TextWrapping="Wrap" />
<Slider Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="1"
Width="200"
Margin="0,0,12,0"
VerticalAlignment="Center"
IsSnapToTickEnabled="True"
Maximum="1"
Minimum="0.01"
TickFrequency="0.01"
Value="{Binding Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold, Mode=TwoWay}" />
<TextBlock Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="2"
MinWidth="40"
Margin="0,0,36,0"
HorizontalAlignment="Left"
VerticalAlignment="Center"
FontFamily="Cascadia Mono, Consolas, Courier New, monospace"
FontSize="14"
Text="{Binding Config.OtherConfig.OcrConfig.OcrMatchDefaultThreshold, StringFormat=F2}" />
</Grid>
<Grid Margin="16">
<Grid.RowDefinitions>
<RowDefinition Height="Auto" />
<RowDefinition Height="Auto" />
</Grid.RowDefinitions>
<Grid.ColumnDefinitions>
<ColumnDefinition Width="*" />
<ColumnDefinition Width="Auto" />
<ColumnDefinition Width="Auto" />
</Grid.ColumnDefinitions>
<ui:TextBlock Grid.Row="0"
Grid.Column="0"
FontTypography="Body"
Text="PaddleOCR 识别置信度阈值"
TextWrapping="Wrap" />
<ui:TextBlock Grid.Row="1"
Grid.Column="0"
Foreground="{ui:ThemeResource TextFillColorTertiaryBrush}"
Text="【0~1】低于此阈值的字符将被过滤,默认值为 0.5。修改后需重载OCR。"
TextWrapping="Wrap" />
<Slider Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="1"
Width="200"
Margin="0,0,12,0"
VerticalAlignment="Center"
IsSnapToTickEnabled="True"
Maximum="0.99"
Minimum="0"
TickFrequency="0.01"
Value="{Binding Config.OtherConfig.OcrConfig.PaddleOcrThreshold, Mode=TwoWay}" />
<TextBlock Grid.Row="0"
Grid.RowSpan="2"
Grid.Column="2"
MinWidth="40"
Margin="0,0,36,0"
HorizontalAlignment="Left"
VerticalAlignment="Center"
FontFamily="Cascadia Mono, Consolas, Courier New, monospace"
FontSize="14"
Text="{Binding Config.OtherConfig.OcrConfig.PaddleOcrThreshold, StringFormat=F2}" />
</Grid>
</StackPanel>
</ui:CardExpander>

View File

@@ -0,0 +1,195 @@
using BetterGenshinImpact.Core.Recognition.OCR;
using OpenCvSharp;
namespace BetterGenshinImpact.UnitTest.CoreTests.RecognitionTests.OCRTests;
public class OcrMatchFallbackServiceTests
{
#region LevenshteinDistance
[Fact]
public void LevenshteinDistance_IdenticalStrings_ReturnsZero()
{
Assert.Equal(0, OcrMatchFallbackService.LevenshteinDistance("abc", "abc"));
}
[Fact]
public void LevenshteinDistance_EmptyAndNonEmpty_ReturnsLength()
{
Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("", "abc"));
Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("abc", ""));
}
[Fact]
public void LevenshteinDistance_BothEmpty_ReturnsZero()
{
Assert.Equal(0, OcrMatchFallbackService.LevenshteinDistance("", ""));
}
[Fact]
public void LevenshteinDistance_SingleSubstitution()
{
// "确认" vs "确忍": a single character substitution
Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("确认", "确忍"));
}
[Fact]
public void LevenshteinDistance_Insertion()
{
Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("ac", "abc"));
}
[Fact]
public void LevenshteinDistance_Deletion()
{
Assert.Equal(1, OcrMatchFallbackService.LevenshteinDistance("abc", "ac"));
}
[Fact]
public void LevenshteinDistance_CompletelyDifferent()
{
Assert.Equal(3, OcrMatchFallbackService.LevenshteinDistance("abc", "xyz"));
}
#endregion
#region ComputeTextSimilarity
[Fact]
public void ComputeTextSimilarity_ExactMatch_ReturnsOne()
{
Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("确认", "确认"));
}
[Fact]
public void ComputeTextSimilarity_TextContainsTarget_ReturnsOne()
{
// "确认购买" contains "确认"
Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("确认购买", "确认"));
}
[Fact]
public void ComputeTextSimilarity_TargetContainsText_ReturnsRatio()
{
// "确认" is contained in "确认购买", so the length ratio is 2/4
Assert.Equal(0.5, OcrMatchFallbackService.ComputeTextSimilarity("确认", "确认购买"));
}
[Fact]
public void ComputeTextSimilarity_EmptyTarget_ReturnsOne()
{
Assert.Equal(1.0, OcrMatchFallbackService.ComputeTextSimilarity("任意文字", ""));
}
[Fact]
public void ComputeTextSimilarity_EmptyText_ReturnsZero()
{
Assert.Equal(0.0, OcrMatchFallbackService.ComputeTextSimilarity("", "确认"));
}
[Fact]
public void ComputeTextSimilarity_SingleCharDifference()
{
// "确忍" vs "确认": distance 1, max length 2, similarity = 1 - 1/2 = 0.5
Assert.Equal(0.5, OcrMatchFallbackService.ComputeTextSimilarity("确忍", "确认"));
}
[Fact]
public void ComputeTextSimilarity_CompletelyDifferent_ReturnsZero()
{
// Completely different strings
Assert.Equal(0.0, OcrMatchFallbackService.ComputeTextSimilarity("甲乙", "丙丁"));
}
#endregion
#region OcrMatch / OcrMatchDirect (using FakeOcrService)
[Fact]
public void OcrMatch_WhenRegionContainsTarget_ReturnsOne()
{
var fakeOcr = new FakeOcrService(new OcrResult([
new OcrResultRegion(default, "确认购买", 0.9f)
]));
var sut = new OcrMatchFallbackService(fakeOcr);
using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White);
var score = sut.OcrMatch(mat, "确认");
Assert.Equal(1.0, score);
}
[Fact]
public void OcrMatch_MultipleRegions_ReturnsBestScore()
{
var fakeOcr = new FakeOcrService(new OcrResult([
new OcrResultRegion(default, "其他文字", 0.9f),
new OcrResultRegion(default, "确认", 0.9f)
]));
var sut = new OcrMatchFallbackService(fakeOcr);
using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White);
var score = sut.OcrMatch(mat, "确认");
Assert.Equal(1.0, score);
}
[Fact]
public void OcrMatch_NoRegions_ReturnsZero()
{
var fakeOcr = new FakeOcrService(new OcrResult([]));
var sut = new OcrMatchFallbackService(fakeOcr);
using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White);
var score = sut.OcrMatch(mat, "确认");
Assert.Equal(0.0, score);
}
[Fact]
public void OcrMatchDirect_ExactMatch_ReturnsOne()
{
var fakeOcr = new FakeOcrService(ocrWithoutDetectorResult: "确认");
var sut = new OcrMatchFallbackService(fakeOcr);
using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White);
var score = sut.OcrMatchDirect(mat, "确认");
Assert.Equal(1.0, score);
}
[Fact]
public void OcrMatchDirect_PartialMatch_ReturnsPartialScore()
{
var fakeOcr = new FakeOcrService(ocrWithoutDetectorResult: "确忍");
var sut = new OcrMatchFallbackService(fakeOcr);
using var mat = new Mat(50, 200, MatType.CV_8UC3, Scalar.White);
var score = sut.OcrMatchDirect(mat, "确认");
Assert.Equal(0.5, score, 0.01);
}
#endregion
/// <summary>
/// A fake IOcrService used to test OcrMatchFallbackService.
/// </summary>
private class FakeOcrService : IOcrService
{
private readonly OcrResult? _ocrResult;
private readonly string _ocrWithoutDetectorResult;
public FakeOcrService(OcrResult? ocrResult = null, string ocrWithoutDetectorResult = "")
{
_ocrResult = ocrResult;
_ocrWithoutDetectorResult = ocrWithoutDetectorResult;
}
public string Ocr(Mat mat) => _ocrResult?.Text ?? "";
public string OcrWithoutDetector(Mat mat) => _ocrWithoutDetectorResult;
public OcrResult OcrResult(Mat mat) => _ocrResult ?? new OcrResult([]);
}
}
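The tests above pin down the behaviour of the fallback scoring fairly tightly: an empty target scores 1, empty text scores 0, text containing the target scores 1, a target containing the text scores the length ratio, and everything else scores `1 - distance / maxLength`. The sketch below is one implementation consistent with those expectations. `TextSimilaritySketch` and its method names are hypothetical, and the case-insensitive `Contains` is inferred from the commit message; the real `OcrMatchFallbackService` is not shown in this hunk and may differ.

```csharp
using System;

// Hypothetical reconstruction from the test expectations, not the PR's implementation.
public static class TextSimilaritySketch
{
    // Classic edit distance, computed with two rolling rows.
    public static int Levenshtein(string a, string b)
    {
        var prev = new int[b.Length + 1];
        var curr = new int[b.Length + 1];
        for (var j = 0; j <= b.Length; j++) prev[j] = j;

        for (var i = 1; i <= a.Length; i++)
        {
            curr[0] = i;
            for (var j = 1; j <= b.Length; j++)
            {
                var cost = a[i - 1] == b[j - 1] ? 0 : 1;
                curr[j] = Math.Min(Math.Min(prev[j] + 1, curr[j - 1] + 1), prev[j - 1] + cost);
            }
            (prev, curr) = (curr, prev);
        }
        return prev[b.Length];
    }

    // Score in [0, 1] describing how well the OCR text matches the target.
    public static double Similarity(string text, string target)
    {
        if (string.IsNullOrEmpty(target)) return 1.0;
        if (string.IsNullOrEmpty(text)) return 0.0;
        if (text.Contains(target, StringComparison.OrdinalIgnoreCase)) return 1.0;
        if (target.Contains(text, StringComparison.OrdinalIgnoreCase))
            return (double)text.Length / target.Length;

        var distance = Levenshtein(text, target);
        return 1.0 - (double)distance / Math.Max(text.Length, target.Length);
    }
}
```

The asymmetry is deliberate in this reading of the tests: extra text around the target is not penalised, while OCR output that covers only part of the target is scored by how much of it was read.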

View File

@@ -0,0 +1,291 @@
using BetterGenshinImpact.Core.Recognition.OCR.Engine;
namespace BetterGenshinImpact.UnitTest.CoreTests.RecognitionTests.OCRTests;
public class OcrUtilsTests
{
#region CreateLabelDict
[Fact]
public void CreateLabelDict_SingleCharLabels_MapsCorrectly()
{
// labels ["a","b","c"] → a=1, b=2, c=3, " "=4
IReadOnlyList<string> labels = ["a", "b", "c"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
Assert.Equal(1, dict["a"]);
Assert.Equal(2, dict["b"]);
Assert.Equal(3, dict["c"]);
Assert.Equal(4, dict[" "]);
// Every label has length 1, so labelLengths = [1]
Assert.Single(lengths);
Assert.Equal(1, lengths[0]);
}
[Fact]
public void CreateLabelDict_NoZeroLength()
{
// Must not contain a zero-length entry, which would cause an infinite loop
IReadOnlyList<string> labels = ["x", "y"];
OcrUtils.CreateLabelDict(labels, out var lengths);
Assert.DoesNotContain(0, lengths);
}
[Fact]
public void CreateLabelDict_LengthsDescendingOrder()
{
// With multi-character labels, labelLengths must be in descending order (try the longest match first)
IReadOnlyList<string> labels = ["a", "ab", "b"];
OcrUtils.CreateLabelDict(labels, out var lengths);
for (var i = 0; i < lengths.Length - 1; i++)
{
Assert.True(lengths[i] >= lengths[i + 1], "labelLengths 应为降序");
}
}
#endregion
#region MapStringToLabelIndices
[Fact]
public void MapStringToLabelIndices_SimpleMatch()
{
// labels: ["a","b","c"] → a=1, b=2, c=3
IReadOnlyList<string> labels = ["a", "b", "c"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("abc", dict, lengths);
Assert.Equal([1, 2, 3], result);
}
[Fact]
public void MapStringToLabelIndices_SkipsUnknownChars()
{
// The X in "aXb" is not a label and should be skipped
IReadOnlyList<string> labels = ["a", "b"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("aXb", dict, lengths);
Assert.Equal([1, 2], result);
}
[Fact]
public void MapStringToLabelIndices_PrefersLongerMatch()
{
// The labels include both "ab" and "a"; the input "ab" should prefer the longer label "ab"
IReadOnlyList<string> labels = ["a", "ab", "b"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("ab", dict, lengths);
// "ab" matches as a whole and maps to index 2 (the second element of labels)
Assert.Single(result);
Assert.Equal(2, result[0]);
}
[Fact]
public void MapStringToLabelIndices_EmptyString_ReturnsEmpty()
{
IReadOnlyList<string> labels = ["a", "b"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("", dict, lengths);
Assert.Empty(result);
}
[Fact]
public void MapStringToLabelIndices_AllUnknown_ReturnsEmpty()
{
IReadOnlyList<string> labels = ["a", "b"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("XYZ", dict, lengths);
Assert.Empty(result);
}
[Fact]
public void MapStringToLabelIndices_SpaceChar_MapsToSpaceIndex()
{
// The space character maps to labels.Count + 1
IReadOnlyList<string> labels = ["a", "b"];
var dict = OcrUtils.CreateLabelDict(labels, out var lengths);
var result = OcrUtils.MapStringToLabelIndices("a b", dict, lengths);
// a=1, " "=3, b=2
Assert.Equal([1, 3, 2], result);
}
#endregion
#region GetMaxScoreDP
[Fact]
public void GetMaxScoreDP_PerfectMatch_ReturnsFullScore()
{
// result contains every element of target in order, each with confidence 1.0
(int, float)[] result = [(1, 1.0f), (2, 1.0f), (3, 1.0f)];
int[] target = [1, 2, 3];
var score = OcrUtils.GetMaxScoreDp(result, target, target.Length);
Assert.Equal(1.0, score);
}
[Fact]
public void GetMaxScoreDP_NoMatch_ReturnsZero()
{
// result contains none of the elements of target
(int, float)[] result = [(4, 1.0f), (5, 1.0f)];
int[] target = [1, 2];
var score = OcrUtils.GetMaxScoreDp(result, target, target.Length);
Assert.Equal(0, score);
}
[Fact]
public void GetMaxScoreDP_EmptyTarget_ReturnsZero()
{
(int, float)[] result = [(1, 1.0f)];
int[] target = [];
var score = OcrUtils.GetMaxScoreDp(result, target, 1);
Assert.Equal(0, score);
}
[Fact]
public void GetMaxScoreDP_PartialMatch_ReturnsZero()
{
// target needs [1,2,3] but result only has [1,2], so a complete match is impossible
(int, float)[] result = [(1, 1.0f), (2, 1.0f)];
int[] target = [1, 2, 3];
var score = OcrUtils.GetMaxScoreDp(result, target, target.Length);
Assert.Equal(0, score);
}
[Fact]
public void GetMaxScoreDP_SubsequenceMatch_SkipsNoise()
{
// result contains noise, but the subsequence [1,2,3] can still be matched
(int, float)[] result = [(9, 0.5f), (1, 0.8f), (9, 0.3f), (2, 0.9f), (3, 0.7f)];
int[] target = [1, 2, 3];
var score = OcrUtils.GetMaxScoreDp(result, target, target.Length);
// (0.8 + 0.9 + 0.7) / 3 = 0.8
Assert.Equal(0.8, score, 0.01);
}
[Fact]
public void GetMaxScoreDP_PicksBestConfidence()
{
// target is [1] and result has two entries with index 1; the one with the highest confidence should be chosen
(int, float)[] result = [(1, 0.3f), (1, 0.9f)];
int[] target = [1];
var score = OcrUtils.GetMaxScoreDp(result, target, 1);
Assert.Equal(0.9, score, 0.01);
}
[Fact]
public void GetMaxScoreDP_NormalizesWithAvailableCount()
{
// When availableCount > target.Length the score is diluted
(int, float)[] result = [(1, 1.0f), (2, 1.0f)];
int[] target = [1, 2];
var score = OcrUtils.GetMaxScoreDp(result, target, 4);
// (1.0 + 1.0) / 4 = 0.5
Assert.Equal(0.5, score, 0.01);
}
[Fact]
public void GetMaxScoreDP_ManyFrames_TargetLengthDenominator_ScoresHigh()
{
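// Simulates matching after the character frames of several text regions are merged; the denominator should be target.Length
// Even with many noise frames, the score should stay high as long as target is fully matched
(int, float)[] result = [
(9, 0.5f), (8, 0.6f), (7, 0.4f), // noise region 1
(1, 0.9f), (2, 0.85f), // matches target [1,2]
(6, 0.7f), (5, 0.3f), (4, 0.5f), // noise region 2
(9, 0.2f), (8, 0.4f) // noise region 3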
];
int[] target = [1, 2];
// With target.Length as the denominator: (0.9 + 0.85) / 2 = 0.875
var score = OcrUtils.GetMaxScoreDp(result, target, target.Length);
Assert.Equal(0.875, score, 0.01);
}
#endregion
#region CreateWeights
[Fact]
public void CreateWeights_DefaultsToOne()
{
IReadOnlyList<string> labels = ["a", "b", "c"];
var labelDict = OcrUtils.CreateLabelDict(labels, out _);
var weights = OcrUtils.CreateWeights(new Dictionary<string, float>(), labelDict, labels.Count);
// labels.Count + 2 = 5
Assert.Equal(5, weights.Length);
Assert.All(weights, w => Assert.Equal(1.0f, w));
}
[Fact]
public void CreateWeights_AppliesExtraWeights()
{
IReadOnlyList<string> labels = ["a", "b", "c"];
var extra = new Dictionary<string, float> { { "b", 2.5f } };
var labelDict = OcrUtils.CreateLabelDict(labels, out _);
var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count);
// "b" is labels[1], so its index is 2
Assert.Equal(1.0f, weights[1]); // "a"
Assert.Equal(2.5f, weights[2]); // "b"
Assert.Equal(1.0f, weights[3]); // "c"
}
[Fact]
public void CreateWeights_IgnoresUnknownKeys()
{
IReadOnlyList<string> labels = ["a", "b"];
var extra = new Dictionary<string, float> { { "z", 5.0f } };
var labelDict = OcrUtils.CreateLabelDict(labels, out _);
var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count);
Assert.All(weights, w => Assert.Equal(1.0f, w));
}
[Fact]
public void CreateWeights_SpaceKey_MapsToCorrectIndex()
{
// The space weight should be written at index labels.Count + 1, consistent with CreateLabelDict
IReadOnlyList<string> labels = ["a", " ", "b"];
var extra = new Dictionary<string, float> { { " ", 3.0f } };
var labelDict = OcrUtils.CreateLabelDict(labels, out _);
var weights = OcrUtils.CreateWeights(extra, labelDict, labels.Count);
// labels.Count + 1 = 4, so the space weight should be at weights[4]
Assert.Equal(3.0f, weights[labels.Count + 1]);
// The slot for " " by its position in labels (index=2, i.e. weights[2]) must not be written by mistake
Assert.Equal(1.0f, weights[2]);
}
#endregion
}

@@ -0,0 +1,74 @@
using BetterGenshinImpact.Helpers;
namespace BetterGenshinImpact.UnitTest.HelperTests;
public class LruCacheTests
{
[Fact]
public void BasicSetGetTest()
{
var cache = new CacheHelper.LruCache<string, string>(3);
cache.Set("a", "1");
cache.Set("b", "2");
cache.Set("c", "3");
Assert.True(cache.TryGet("a", out var v1) && v1 == "1");
Assert.True(cache.TryGet("b", out var v2) && v2 == "2");
Assert.True(cache.TryGet("c", out var v3) && v3 == "3");
}
[Fact]
public void LruEvictionTest()
{
var cache = new CacheHelper.LruCache<string, string>(2);
cache.Set("a", "1");
cache.Set("b", "2");
cache.Set("c", "3"); // "a" should be evicted
Assert.False(cache.TryGet("a", out _));
Assert.True(cache.TryGet("b", out var v2) && v2 == "2");
Assert.True(cache.TryGet("c", out var v3) && v3 == "3");
}
[Fact]
public void UpdateMovesToHeadTest()
{
var cache = new CacheHelper.LruCache<string, string>(2);
cache.Set("a", "1");
cache.Set("b", "2");
cache.TryGet("a", out _); // "a" becomes the most recently used entry
cache.Set("c", "3"); // "b" should be evicted
Assert.True(cache.TryGet("a", out _));
Assert.False(cache.TryGet("b", out _));
Assert.True(cache.TryGet("c", out _));
}
[Fact]
public void ExpireTest()
{
var cache = new CacheHelper.LruCache<string, string>(2, TimeSpan.FromMilliseconds(500));
cache.Set("a", "1");
Assert.True(cache.TryGet("a", out var v) && v == "1");
Thread.Sleep(650);
Assert.False(cache.TryGet("a", out _));
}
[Fact]
public void RemoveTest()
{
var cache = new CacheHelper.LruCache<string, string>(2);
cache.Set("a", "1");
Assert.True(cache.Remove("a"));
Assert.False(cache.TryGet("a", out _));
}
[Fact]
public void ClearTest()
{
var cache = new CacheHelper.LruCache<string, string>(2);
Assert.Equal(0, cache.Count);
cache.Set("a", "1");
cache.Set("b", "2");
Assert.Equal(2, cache.Count);
cache.Clear();
Assert.Equal(0, cache.Count);
}
}