Compare commits

...

4 Commits

Author SHA1 Message Date
luguoyixiazi
a6a53bb46b Update README.md 2025-11-04 18:30:34 +08:00
luguoyixiazi
e98e47a0af Update README.md 2025-11-04 17:22:21 +08:00
luguoyixiazi
84d225dc14 Update README.md 2025-11-04 17:20:25 +08:00
luguoyixiazi
b4d35b9906 新模型及fullpage.9.2.0-guwyxh.js适配
Fixes #5
添加dinov3、任务头及yolo11n用于通过二者的验证码
同时更新fullpage.9.2.0-guwyxh.js中新的常量
```json
{
"captcha_token":"2064329542",
"tsfq":"xovrayel"
}
```
2025-11-04 16:40:02 +08:00
4 changed files with 685 additions and 82 deletions

View File

@@ -8,19 +8,28 @@
## Reference projects
Model and V4 dataset: https://github.com/taisuii/ClassificationCaptchaOcr
resnet model and V4 dataset: https://github.com/taisuii/ClassificationCaptchaOcr
Click-detection model: https://github.com/Peterande/D-FINE
Another click-detection model: https://github.com/facebookresearch/dinov3
Another click-detection model: https://github.com/ultralytics/ultralytics
api: https://github.com/ravizhan/geetest-v3-click-crack
Thanks to @kissnavel for the apis for the two games I don't play
## How to run
### 1. Install dependencies (required for local runs; docker users skip to [5-b](#docker))
(Optional) a. To train with paddle you also need to install paddlex and its image-classification module; for installation see https://github.com/PaddlePaddle/PaddleX
(Optional) b. For d-fine training see https://github.com/Peterande/D-FINE
* (Required) c. Create a model folder under the project directory and put the model files in it, named e.g. resnet18.onnx or PP-HGNetV2-B4.onnx. PP-HGNetV2-B4 is used by default; to use resnet, set use_v3_model to False (the models' inputs and outputs differ; adjust as needed). Place the d-fine model in the same path; it is used to pass the click captcha
(Optional) a-0. To train with paddle you also need to install paddlex and its image-classification module; for installation see [https://github.com/PaddlePaddle/PaddleX](https://github.com/PaddlePaddle/PaddleX)
(Optional) a-1. For d-fine training see [https://github.com/Peterande/D-FINE](https://github.com/Peterande/D-FINE)
(Optional) a-2. For dinov3 usage see [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3); the classification is based on patch tokens
(Optional) a-3. For yolo training see [https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)
(Optional) a-4. The dinov3-based classification architecture is extremely bare-bones: it extracts features for every located crop in a group of images, then, within the group, takes the bottom crop as the anchor, marks one top crop as the positive sample and the rest as negatives, and learns which dimensions of the dinov3 features drive the similarity comparison (a training sketch follows this list)
* (Required) b. Create a model folder under the project directory and put the model files in it, named e.g. resnet18.onnx or PP-HGNetV2-B4.onnx. PP-HGNetV2-B4 is used by default; to use resnet, set use_v3_model to False (the models' inputs and outputs differ; adjust as needed). Place the d-fine model in the same path; it is used to pass the click captcha. In addition, dinov3, yolo11n, and the task head I built combine into a single pipeline, so all of them must go in as well
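For illustration, a minimal sketch of the a-4 training idea from the list above: features are assumed to be precomputed dinov3 vectors, and the linear head, its dimensions, and the temperature are illustrative assumptions, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

head = torch.nn.Linear(384, 128)  # hypothetical projection head; 384 = assumed dinov3 feature dim

def group_loss(anchor_feat: torch.Tensor, top_feats: torch.Tensor, positive_idx: int) -> torch.Tensor:
    """anchor_feat: (D,); top_feats: (N, D); exactly one of the N top crops is the positive."""
    emb = F.normalize(head(torch.vstack([anchor_feat[None, :], top_feats])), dim=-1)
    sims = emb[0] @ emb[1:].T  # cosine similarity of the anchor to every top crop
    # InfoNCE-style objective: the positive must out-score all negatives
    return F.cross_entropy(sims[None, :] / 0.07, torch.tensor([positive_idx]))
```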
```
pip install -r requirements.txt
@@ -38,7 +47,7 @@ pip install -r requirements_without_train.txt
- For dataset details see the projects referenced above, but those are V4 datasets and V3 has no demo, so improvise; training on V4 and using it for V3 without code changes gives a rather underwhelming accuracy
- The main issue is that V4 sizes differ from V3: the V4 api returns two images directly, one target image and one nine-grid image, while V3 puts them together so the target must be cropped out, and the V3 target image is very low resolution. A cropped V4 nine-grid tile is 100 * 86 with the black border removed, while a cropped V3 tile is 112 * 112. I'm not sure what transform V4 applies to the nine-grid content relative to V3; adjusting the preprocessing is all it takes (a sketch follows below)
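A minimal sketch of the preprocessing tweak meant above, assuming the classifier expects V4-style 100 * 86 tiles; the plain resize is an assumption, not the exact transform V4 applies:

```python
from PIL import Image

def v3_tile_to_v4_shape(tile: Image.Image) -> Image.Image:
    """Resize a cropped 112x112 V3 tile to the 100x86 shape of V4 tiles."""
    assert tile.size == (112, 112)  # V3 tile size per the note above
    return tile.resize((100, 86), Image.Resampling.BILINEAR)
```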
##### b. Training PP-HGNetV2-B4 (optional)
##### a-0. Training PP-HGNetV2-B4 (optional)
A dataset casually found on paddle, formatted as below; if training on V4 for V3, it's advisable to add more augmentations
@@ -50,17 +59,22 @@ pip install -r requirements_without_train.txt
└─ validation and test sets, same as above
```
##### b. Training d-fine (optional)
##### a-1. Training d-fine (optional)
Dataset format as specified in d-fine; if you don't modify the source, num_classes needs +1. COCO format works. I used 320*320; in the dataloader I commented out RandomZoomOut, RandomHorizontalFlip, and RandomIoUCrop, since I handle all of those during dataset generation
##### c. To crop V3 nine-grid images use crop_image_v3 from crop_image.py; for V4 use crop_image; write your own cropping script
##### a-2. Training the dinov3 + yolo pipeline (optional)
The dataset format is my own design: each load uses one anchor, one positive, and several negatives, so I organized the format up front (raw annotation was still done in COCO). All annotations on an image are combined into two lists split along the y axis into top and bottom, with bottom as the anchor (it makes no real difference, but actual use never compares within the top group, so skipping it saves resources); within top, one sample is ✅ and the rest are ❌ (a grouping sketch follows below)
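The grouping sketch referenced above, assuming COCO-style `bbox = [x, y, w, h]` annotations; the split threshold and field names are illustrative:

```python
def group_by_y(annotations: list[dict], split_y: float) -> dict:
    """Split one image's boxes into top/bottom groups; bottom boxes serve as anchors."""
    groups = {"top": [], "bottom": []}
    for ann in annotations:  # ann["bbox"] = [x, y, w, h]
        key = "top" if ann["bbox"][1] < split_y else "bottom"
        groups[key].append(ann)
    return groups
```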
##### b. To crop V3 nine-grid images use crop_image_v3 from crop_image.py; for V4 use crop_image; write your own cropping script
### 3. Training models (optional)
- To train resnet18, run `python train.py`
- To train PP-HGNetV2-B4, run `python train_paddle.py`
- To train d-fine, follow the original project; a single model covers it, with somewhat less resource overhead than ddddocr + similarity detection
- The training code is extremely simple and, given the approach I laid out (i.e. the dataset organization), an llm can generate it directly. For yolo training just follow its repo. dinov3 was neither trained nor used for localization, since a localization head would be larger than yolo11n and not worth it
### 4-a. Converting the PP-HGNetV2-B4 and resnet models to onnx (optional)
@@ -71,18 +85,30 @@ pip install -r requirements_without_train.txt
- Convert as in the original project
- Image preprocessing at inference should match training; the onnx inference preprocessing in the d-fine repo does not match training…
### 4-c. Converting the composite models to onnx (optional)
- yolo: convert as in the original project
- dinov3: the official export works as-is; it is only used to extract features
- The classifier is also exported with the built-in pytorch/onnx functions; it is so small that nothing special is needed
- They are not merged into one model mainly so that future updates can reuse the features dinov3 extracts. Those features could in fact handle every task above: the localization model would be a bit larger, but a head of a few hundred KB can replace the nine-grid classification model, and if click localization only needs to separate the top part from the bottom, an even smaller model might do
### 5-a. Starting the fastapi service (requires trained models in onnx format)
Run `python main.py`; the paddle onnx model is used by default. To use resnet18, edit the comments yourself, or run `uvicorn main:app --host 0.0.0.0 --port 9645 --reload`
Environment variables control which models are loaded; use_pdl, use_dfine, and use_multi are currently supported. Logging also has an environment variable, LOG_LEVEL
Due to trajectory issues, a verification can be correct yet the result still fail, so raising the retry count is recommended; the trained paddle model's accuracy is above 99.9%
Run `python main.py`; the nine-grid uses the paddle onnx model by default. To use resnet18, edit the comments yourself, or run `uvicorn main:app --host 0.0.0.0 --port 9645 --reload`
Due to trajectory issues, a verification can be correct yet the result still fail, so raising the retry count is recommended; the trained paddle model's accuracy is above 99.9%,
and because the dinov3 onnx is not exactly equivalent to the pth, the head trained on pth features loses a little accuracy (1 error out of 4613 images; a quantized dinov3 gets a few more wrong). You could train on features extracted by the quantized onnx instead, but that is much slower than pth, so maybe later (a parity-check sketch follows below)
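The parity-check sketch mentioned above; how the two feature arrays are produced is left out, only the comparison is shown:

```python
import numpy as np

def feature_parity(onnx_feats: np.ndarray, pth_feats: np.ndarray) -> float:
    """Cosine similarity between onnx and pth features of the same image;
    values slightly below 1.0 are what costs the pth-trained head a little accuracy."""
    a, b = onnx_feats.flatten(), pth_feats.flatten()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
```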
### 5-b. Starting the service with docker
The image is <span id="docker">luguoyixiazi/test_nine:25.7.2</span>
The image is <span id="docker">luguoyixiazi/test_nine:25.7.2</span>; this build is only for those who need just the nine-grid and the memo
At runtime, just bind a port and set the two environment variables `use_pdl` and `use_dfine` (1 enables a model, 0 disables it; both default to enabled). The api endpoint is /pass_uni with required parameters gt and challenge; separate pass_nine and pass_icon endpoints are also provided, with more optional parameters
The image is <span id="docker">luguoyixiazi/test_nine:25.11.2</span>; this build additionally supports 🛤 and 3z
At runtime, just bind a port and set the three environment variables `use_pdl`, `use_dfine`, and `use_multi` (1 enables a model, 0 disables it; all default to enabled). The api endpoint is /pass_uni with required parameters gt and challenge; separate pass_nine and pass_icon endpoints are also provided, with more optional parameters
### 6. Calling the api
@@ -103,8 +129,13 @@ def game_captcha(gt: str, challenge: str):
For Snap Hutao, the service endpoint is /pass_hutao and the return value is already aligned; just fill in the api as `(http://127.0.0.1:9645/pass_hutao?gt={0}&challenge={1})`
Helper apis have been added for loading/unloading models, listing models, and changing the log level
For concrete calling code see the consuming projects; this is only an example of the API url and parameters (a usage sketch follows below)
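The usage sketch mentioned above; the url and the gt/challenge parameters are from this README, the surrounding code is illustrative:

```python
import httpx

resp = httpx.get(
    "http://127.0.0.1:9645/pass_uni",
    params={"gt": "<gt from the caller>", "challenge": "<challenge from the caller>"},
    timeout=30,
)
print(resp.json())
```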
#### --Promotion--
Everyone is welcome to support my other projects (they pair well together), meow~~~~~~~~

View File

@@ -1,16 +1,19 @@
import os
import hashlib
import json
import math
import random
import time
import logging
import httpx
from os import path as PATH
from crop_image import validate_path
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from crop_image import validate_path
from os import path as PATH
import os
logger = logging.getLogger(__name__)
class Crack:
def __init__(self, gt=None, challenge=None):
self.pic_path = None
@@ -39,6 +42,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
res = self.session.get(url)
data = json.loads(res.text[1:-1])["data"]
logger.debug(f"Fetched again: {data}")
return data
@staticmethod
@@ -358,6 +362,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
}
resp = self.session.get("https://api.geetest.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
logger.debug(f"Got c/s result: {data}")
self.c = data["c"]
self.s = data["s"]
return data["c"], data["s"]
@@ -398,9 +403,79 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
]
tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
rp = self.MD5(self.gt + self.challenge + self.s)
temp1 = '''"lang":"zh-cn","type":"fullpage","tt":"%s","light":"DIV_0","s":"c7c3e21112fe4f741921cb3e4ff9f7cb","h":"321f9af1e098233dbd03f250fd2b5e21","hh":"39bd9cad9e425c3a8f51610fd506e3b3","hi":"09eb21b3ae9542a9bc1e8b63b3d9a467","vip_order":-1,"ct":-1,"ep":{"v":"9.1.9-dbjg5z","te":false,"me":true,"ven":"Google Inc. (Intel)","ren":"ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)","fp":["scroll",0,1602,1724571628498,null],"lp":["up",386,217,1724571629854,"pointerup"],"em":{"ph":0,"cp":0,"ek":"11","wd":1,"nt":0,"si":0,"sc":0},"tm":{"a":1724571567311,"b":1724571567549,"c":1724571567562,"d":0,"e":0,"f":1724571567312,"g":1724571567312,"h":1724571567312,"i":1724571567317,"j":1724571567423,"k":1724571567330,"l":1724571567423,"m":1724571567545,"n":1724571567547,"o":1724571567569,"p":1724571568259,"q":1724571568259,"r":1724571568261,"s":1724571570378,"t":1724571570378,"u":1724571570380},"dnf":"dnf","by":0},"passtime":1600,"rp":"%s",''' % (
tt, rp)
r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
payload_dict = {
"lang": "zh-cn",
"type": "fullpage",
"tt": tt, # 使用变量
"light": "DIV_0",
"s": "c7c3e21112fe4f741921cb3e4ff9f7cb",
"h": "321f9af1e098233dbd03f250fd2b5e21",
"hh": "39bd9cad9e425c3a8f51610fd506e3b3",
"hi": "09eb21b3ae9542a9bc1e8b63b3d9a467",
"vip_order": -1,
"ct": -1,
"ep": {
"v": "9.2.0-guwyxh",
"te": False, # JSON 'false' 对应 Python 'False'
"me": True, # JSON 'true' 对应 Python 'True'
"ven": "Google Inc. (Intel)",
"ren": "ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)",
"fp": [
"scroll",
0,
1602,
1724571628498,
None # JSON 'null' maps to Python None
],
"lp": [
"up",
386,
217,
1724571629854,
"pointerup"
],
"em": {
"ph": 0,
"cp": 0,
"ek": "11",
"wd": 1,
"nt": 0,
"si": 0,
"sc": 0
},
"tm": {
"a": 1724571567311,
"b": 1724571567549,
"c": 1724571567562,
"d": 0,
"e": 0,
"f": 1724571567312,
"g": 1724571567312,
"h": 1724571567312,
"i": 1724571567317,
"j": 1724571567423,
"k": 1724571567330,
"l": 1724571567423,
"m": 1724571567545,
"n": 1724571567547,
"o": 1724571567569,
"p": 1724571568259,
"q": 1724571568259,
"r": 1724571568261,
"s": 1724571570378,
"t": 1724571570378,
"u": 1724571570380
},
"dnf": "dnf",
"by": 0
},
"passtime": 1600,
"rp": rp, # 使用变量
"captcha_token":"2064329542",
"tsfq":"xovrayel"
}
# r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
r = json.dumps(payload_dict, separators=(',', ':'))
ct = self.aes_encrypt(r)
s = [byte for byte in ct]
w = self.encode(s)
@@ -414,6 +489,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
"w": w
}
resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
logger.debug(f"ajax result: {resp}")
return json.loads(resp[22:-1])["data"]
def get_pic(self):
@@ -435,11 +511,15 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
}
resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
logger.debug(f"Got picture result: {data}")
self.pic_path = data["pic"]
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
pic_data = self.session.get(pic_url).content
pic_name = data["pic"].split("/")[-1]
pic_type = data["pic_type"]
if "80/2023-12-04T12/icon" in pic_url:
pic_type = "icon1"
with open(PATH.join(validate_path,pic_name),'wb+') as f:
f.write(pic_data)
return pic_data,pic_name,pic_type
@@ -483,4 +563,4 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
"w": w
}
resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
return resp[1:-1]

main.py
View File

@@ -4,45 +4,86 @@ import time
import httpx
import shutil
import random
import uvicorn
import logging
import logging.config
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI,Query, HTTPException
from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
PORT = 9645
platform = os.name
import logging.handlers
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILENAME = os.path.join(LOG_DIR, "app.log")
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
# --- logging configuration dict ---
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
# define a formatter named "default"
"default": {
"()": f"uvicorn.{'_logging' if platform== 'nt' else 'logging'}.DefaultFormatter",
"fmt": "%(levelprefix)s %(asctime)s | %(message)s",
"()": "logging.Formatter", # 使用标准的 Formatter
"fmt": "[%(levelname)s - %(filename)s:%(lineno)d | %(funcName)s] - %(asctime)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
},
"handlers": {
"default": {
"formatter": "default",
# handler for console output
"console": {
"class": "logging.StreamHandler",
"formatter": "default", # use the "default" format defined above
"stream": "ext://sys.stderr",
},
# handler for file output with rotation
"file_rotating": {
"class": "logging.handlers.TimedRotatingFileHandler",
"formatter": "default", # also uses the "default" format defined above
"filename": LOG_FILENAME, # log file path
"when": "D", # rotate by day ('D' for Day)
"interval": 1, # rotate once per day
"backupCount": 2, # keep 2 old log files (3 days total including the current file)
"encoding": "utf-8",
},
},
"loggers": {
# root logger level set to INFO
"": {"handlers": ["default"], "level": "INFO"},
"uvicorn.error": {"level": "INFO"},
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
# root logger
"": {
# send logs to both the console and file_rotating handlers
"handlers": ["console", "file_rotating"],
"level": log_level,
},
# configure uvicorn's loggers so they also use these settings
"uvicorn": {
"handlers": ["console", "file_rotating"],
"level": "WARNING",
"propagate": False, # 阻止 uvicorn 日志向上传播到根 logger避免重复记录
},
"uvicorn.error": {
"level": "WARNING",
"propagate": True, # uvicorn.error 应该传播,以便根记录器可以捕获它
},
"uvicorn.access": {
"handlers": ["console", "file_rotating"],
"level": log_level,
"propagate": False,
},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI,Query, HTTPException
from predict import (predict_onnx,
predict_onnx_pdl,
predict_onnx_dfine,
predict_dino_classify_pipeline,
load_by,
unload,
get_models,
get_available_models)
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
PORT = 9645
def get_available_hosts() -> set[str]:
"""获取本机所有可用的IPv4地址。"""
import socket
@@ -70,7 +111,7 @@ async def lifespan(app: FastAPI):
port = server.config.port if server else PORT
if host == "0.0.0.0":
available_hosts = get_available_hosts()
logger.info(f"Service addresses (pick as needed; in docker use the host machine's host:{port}):")
logger.info(f"Service addresses (pick as needed; in docker use the host machine's host:{port}; when run via uvicorn, go by the command):")
for h in sorted(list(available_hosts)):
logger.info(f" - http://{h}:{port}")
else:
@@ -78,6 +119,7 @@ async def lifespan(app: FastAPI):
logger.info("Available service paths:")
for route in app.routes:
logger.info(f" -{route.methods} {route.path}")
logger.info("See /docs for detailed api usage")
logger.info("="*50)
yield
@@ -89,9 +131,9 @@ app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)
def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
"""获取信息。"""
logging.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
logger.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
crack = Crack(gt, challenge)
crack.gettype()
logger.debug(f"First fetch: {crack.gettype()}")
crack.get_c_s()
time.sleep(random.uniform(0.4,0.6))
crack.ajax()
@@ -114,7 +156,12 @@ def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -
def do_pass_icon(pic:Any, draw_result: bool) -> list[str]:
"""处理图标点选验证码,返回坐标点列表。"""
result_list = predict_onnx_dfine(pic,draw_result)
print(result_list)
logger.debug(result_list)
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def do_pass_icon0(pic:Any, draw_result: bool) -> list[str]:
"""处理图标点选验证码,返回坐标点列表。"""
result_list = predict_dino_classify_pipeline(pic,draw_result)
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def save_image_for_train(pic_name,pic_type,passed):
@@ -141,14 +188,16 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
# 2. 识别
logger.debug(f"Received picture type: {pic_type}")
if pic_type == "nine":
point_list = do_pass_nine(
pic_content,
use_v3_model=kwargs.get("use_v3_model", True),
point=kwargs.get("point",None)
)
elif pic_type == "icon":
elif pic_type == 'icon': # dino
point_list = do_pass_icon0(pic_content, save_result)
elif pic_type == "icon1": # d-fine
point_list = do_pass_icon(pic_content, save_result)
else:
raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
@@ -169,14 +218,14 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
os.remove(os.path.join(validate_path,pic_name))
total_time = time.monotonic() - start_time
logging.info(
logger.info(
f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
f"结果: {result}"
)
return JSONResponse(content=result)
except Exception as e:
logging.error(f"Service error: {e}", exc_info=True)
logger.error(f"Service error: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
@@ -227,12 +276,50 @@ def pass_hutao(gt: str = Query(...),
return JSONResponse(content=rebuild_content, status_code=original_status_code)
except Exception as e:
logging.error(f"Route rewrite error: {e}", exc_info=True)
logger.error(f"Route rewrite error: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
)
@app.get("/list_model")
def list_model():
return JSONResponse(content = get_models())
@app.get("/list_all_model")
def list_all_model():
return JSONResponse(content = get_available_models())
@app.get("/load_model")
def load_model(name: str = Query(...)):
return JSONResponse(content = load_by(name))
@app.get("/unload_model")
def unload_model(name: str = Query(...)):
return JSONResponse(content = unload(name))
@app.get("/set_log_level")
def set_log_level(level: str = Query(...)):
"""
在服务运行时动态修改所有主要 logger 的日志级别。
例如: /set_log_level?level=DEBUG
"""
# 将字符串级别转换为 logging 模块对应的整数值
level_str = str(level).upper()
numeric_level = getattr(logging, level_str, None)
if not isinstance(numeric_level, int):
raise HTTPException(status_code=400, detail=f"Invalid log level: {level}")
# fetch and update every key logger in the configuration
# 1. the root logger
logging.getLogger().setLevel(numeric_level)
# 2. the uvicorn loggers
logging.getLogger("uvicorn").setLevel(numeric_level)
logging.getLogger("uvicorn.error").setLevel(numeric_level)
logging.getLogger("uvicorn.access").setLevel(numeric_level)
# log at a high level to make sure it is visible
logger.warning(f"All logger levels have been dynamically changed to: {level_str}")
return JSONResponse(content = f"Log level successfully set to {level_str}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app,port=PORT)
uvicorn.run(app,port=PORT,log_config=LOGGING_CONFIG)

View File

@@ -1,12 +1,40 @@
import os
import numpy as np
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
import time
from PIL import Image, ImageDraw
import base64
import random
import logging
import numpy as np
from io import BytesIO
import onnxruntime as ort
from PIL import Image, ImageDraw
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path,save_path
logger = logging.getLogger(__name__)
def safe_load_img(image):
im_pil = None
try:
if isinstance(image, Image.Image):
im_pil = image
elif isinstance(image, str):
try:
im_pil = Image.open(image)
except (IOError, FileNotFoundError):
if ',' in image:
image = image.split(',')[-1]
padding = len(image) % 4
if padding > 0:
image += '=' * (4 - padding)
img_bytes = base64.b64decode(image)
im_pil = Image.open(BytesIO(img_bytes))
elif isinstance(image, bytes):
im_pil = bytes_to_pil(image)
elif isinstance(image, np.ndarray):
im_pil = Image.fromarray(image)
else:
raise ValueError(f"Unsupported input type: {type(image)}")
return im_pil.convert("RGB")
except Exception as e:
raise ValueError(f"Failed to load or parse the image, error: {e}")
def predict(icon_image, bg_image):
import torch
from train import MyResNet18, data_transform
@@ -34,7 +62,7 @@ def predict(icon_image, bg_image):
model = MyResNet18(num_classes=91) # the class count must match training
model.load_state_dict(torch.load(model_path))
model.eval()
print("Model loaded, took:", time.time() - start)
logger.info(f"Model loaded, took: {time.time() - start}")
start = time.time()
target_images = torch.stack(target_images, dim=0)
@@ -52,14 +80,14 @@ def predict(icon_image, bg_image):
)
scores.append(similarity.cpu().item())
# confidences for each tile, left to right, top to bottom
print(scores)
logger.info(scores)
# sort the array, keeping original indices
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# take the three largest scores and their indices
largest_three = sorted_arr[:3]
print(largest_three)
print("Recognition done, took:", time.time() - start)
logger.info(largest_three)
logger.info(f"Recognition done, took: {time.time() - start}")
def load_model(name='PP-HGNetV2-B4.onnx'):
# load the onnx model
@@ -69,7 +97,7 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
model_path = os.path.join(current_dir, 'model', name)
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print(f"Loaded {name} model, took: {time.time() - start}")
logger.info(f"Loaded {name} model, took: {time.time() - start}")
def load_dfine_model(name='d-fine-n.onnx'):
# load the onnx model
@@ -78,8 +106,31 @@ def load_dfine_model(name='d-fine-n.onnx'):
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dfine = ort.InferenceSession(model_path)
print(f"Loaded {name} model, took: {time.time() - start}")
logger.info(f"Loaded {name} model, took: {time.time() - start}")
def load_yolo11n(name='yolo11n.onnx'):
global session_yolo11n
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_yolo11n = ort.InferenceSession(model_path)
logger.info(f"Loaded {name} model, took: {time.time() - start}")
def load_dinov3(name='dinov3-small.onnx'):
global session_dino3
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dino3 = ort.InferenceSession(model_path)
logger.info(f"Loaded {name} model, took: {time.time() - start}")
def load_dino_classify(name='atten.onnx'):
global session_dino_cf
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dino_cf = ort.InferenceSession(model_path)
logger.info(f"Loaded {name} model, took: {time.time() - start}")
def predict_onnx(icon_image, bg_image, point = None):
import cv2
@@ -137,8 +188,7 @@ def predict_onnx(icon_image, bg_image, point = None):
else:
similarity = cosine_similarity(target_output, out_put)
scores.append(similarity)
# confidences for each tile, left to right, top to bottom
# print(scores)
logger.debug(f"Confidences for each tile, left to right, top to bottom:\n{scores}")
# sort the array, keeping original indices
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
@@ -149,7 +199,7 @@ def predict_onnx(icon_image, bg_image, point = None):
# judge by score
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"Recognition done: {answer}, took: {time.time() - start}")
logger.info(f"Recognition done: {answer}, took: {time.time() - start}")
#draw_points_on_image(bg_image, answer)
return answer
@@ -202,40 +252,79 @@ def predict_onnx_pdl(images_path):
if len(answer) == 0:
all_sort =[np.argsort(one) for one in outputs]
answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
print(f"Recognition done: {answer}, took: {time.time() - start}")
logger.info(f"Recognition done: {answer}, took: {time.time() - start}")
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
# draw_points_on_image(bg_image, answer)
#draw_points_on_image(bg_image, answer)
return answer
# d-fine inference code and helpers
def calculate_iou(boxA, boxB):
"""
使用 NumPy 计算两个边界框的交并比 (IoU)。
"""
# coordinates of the intersection rectangle
xA = np.maximum(boxA[0], boxB[0])
yA = np.maximum(boxA[1], boxB[1])
xB = np.minimum(boxA[2], boxB[2])
yB = np.minimum(boxA[3], boxB[3])
# area of the intersection
intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
# areas of the two boxes
boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
# area of the union
union_area = float(boxA_area + boxB_area - intersection_area)
# compute the IoU
if union_area == 0:
return 0.0
iou = intersection_area / union_area
return iou
def non_maximum_suppression(detections, iou_threshold=0.35):
"""
Perform non-maximum suppression (NMS) on detection results.
Args:
detections -- a list where each element is a dict containing 'box' and 'score'.
E.g.: [{'box': [x1, y1, x2, y2], 'score': 0.9, ...}, ...]
iou_threshold -- a float IoU threshold deciding whether boxes overlap.
Returns:
final_detections -- the list of detections kept after NMS.
"""
# 1. check whether the detections are empty
if not detections:
return []
# 2. sort the bounding boxes by confidence score, highest first
# a lambda supplies the sort key
detections.sort(key=lambda x: x['score'], reverse=True)
final_detections = []
# 3. loop until no detections remain
while detections:
# 4. move the current highest-scoring detection (the first one)
# into the final list and remove it from the original list
best_detection = detections.pop(0)
final_detections.append(best_detection)
# 5. compute the IoU between the box just taken and all remaining boxes,
# keeping only those with IoU below the threshold
detections_to_keep = []
for det in detections:
# NMS is assumed to run only within the same class
iou = calculate_iou(best_detection['box'], det['box'])
if iou < iou_threshold:
detections_to_keep.append(det)
# replace the original list with the filtered one for the next iteration
detections = detections_to_keep
return final_detections
@@ -246,10 +335,7 @@ def predict_onnx_dfine(image,draw_result=False):
image_input_name = input_nodes[0].name
size_input_name = input_nodes[1].name
output_names = [node.name for node in output_nodes]
if isinstance(image,bytes):
im_pil = bytes_to_pil(image)
else:
im_pil = Image.open(image).convert("RGB")
im_pil = safe_load_img(image)
w, h = im_pil.size
orig_size_np = np.array([[w, h]], dtype=np.int64)
im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
@@ -340,18 +426,335 @@ def predict_onnx_dfine(image,draw_result=False):
text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
draw.text(text_position, text, fill=color)
if draw_result:
save_path = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path)
print(f"Visualization saved at {save_path}")
print(f"Center points in picture order: {points}")
save_path_temp = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path_temp)
logger.info(f"Visualization temporarily saved at {save_path_temp}; moved to {save_path} after the run completes")
logger.info(f"Center points in picture order: {points}")
return points
print(f"Inference device: {ort.get_device()}")
# yolo inference code and helpers
def predict_onnx_yolo(image):
def filter_Detections(results, thresh = 0.5):
results = results[0]
results = results.transpose()
# if model is trained on 1 class only
if len(results[0]) == 5:
# filter out the detections with confidence > thresh
considerable_detections = [detection for detection in results if detection[4] > thresh]
considerable_detections = np.array(considerable_detections)
return considerable_detections
# if model is trained on multiple classes
else:
A = []
for detection in results:
class_id = detection[4:].argmax()
confidence_score = detection[4:].max()
new_detection = np.append(detection[:4],[class_id,confidence_score])
A.append(new_detection)
A = np.array(A)
# filter out the detections with confidence > thresh
considerable_detections = [detection for detection in A if detection[-1] > thresh]
considerable_detections = np.array(considerable_detections)
return considerable_detections
def NMS(boxes, conf_scores, iou_thresh = 0.55):
# boxes [[x1,y1, x2,y2], [x1,y1, x2,y2], ...]
x1 = boxes[:,0]
y1 = boxes[:,1]
x2 = boxes[:,2]
y2 = boxes[:,3]
areas = (x2-x1)*(y2-y1)
order = conf_scores.argsort()
keep = []
keep_confidences = []
while len(order) > 0:
idx = order[-1]
A = boxes[idx]
conf = conf_scores[idx]
order = order[:-1]
xx1 = np.take(x1, indices= order)
yy1 = np.take(y1, indices= order)
xx2 = np.take(x2, indices= order)
yy2 = np.take(y2, indices= order)
keep.append(A)
keep_confidences.append(conf)
# iou = inter/union
xx1 = np.maximum(x1[idx], xx1)
yy1 = np.maximum(y1[idx], yy1)
xx2 = np.minimum(x2[idx], xx2)
yy2 = np.minimum(y2[idx], yy2)
w = np.maximum(xx2-xx1, 0)
h = np.maximum(yy2-yy1, 0)
intersection = w*h
# union = areaA + other_areas - intersection
other_areas = np.take(areas, indices= order)
union = areas[idx] + other_areas - intersection
iou = intersection/union
booleans = iou < iou_thresh
order = order[booleans]
# order = [2,0,1] booleans = [True, False, True]
# order = [2,1]
return keep, keep_confidences
def rescale_back(results,img_w,img_h,imgsz=384):
cx, cy, w, h, class_id, confidence = results[:,0], results[:,1], results[:,2], results[:,3], results[:, 4], results[:,-1]
cx = cx/imgsz * img_w
cy = cy/imgsz * img_h
w = w/imgsz * img_w
h = h/imgsz * img_h
x1 = cx - w/2
y1 = cy - h/2
x2 = cx + w/2
y2 = cy + h/2
boxes = np.column_stack((x1, y1, x2, y2, class_id))
keep, keep_confidences = NMS(boxes,confidence)
return keep, keep_confidences
im_pil = safe_load_img(image)
im_resized = im_pil.resize((384, 384), Image.Resampling.BILINEAR)
im_data = np.array(im_resized, dtype=np.float32) / 255.0
im_data = im_data.transpose(2, 0, 1)
im_data = np.expand_dims(im_data, axis=0)
res = session_yolo11n.run(None,{"images":im_data})
results = filter_Detections(res)
rescaled_results, confidences = rescale_back(results,im_pil.size[0],im_pil.size[1])
images = {"top":[],"bottom":[]}
for r, conf in zip(rescaled_results, confidences):
x1,y1,x2,y2, cls_id = r
cls_id = int(cls_id)
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
cropped_image = im_pil.crop((x1, y1, x2, y2))
if cls_id == 0:
images['top'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
else:
images['bottom'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
return images
# dinov3 inference code and helpers
def make_lvd_transform(resize_size: int = 224):
"""
返回一个图像预处理函数功能与PyTorch版本相同
"""
# normalization parameters
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def transform(image) -> np.ndarray:
"""
图像预处理转换
Args:
image: PIL Image 或 numpy array (H,W,C) 范围[0,255]
Returns:
numpy array (C,H,W) 标准化后的float32数组
"""
# ensure the input is a PIL image
if isinstance(image, np.ndarray):
image = Image.fromarray(image.astype('uint8'))
# 1. resize (LANCZOS antialiasing, equivalent to antialias=True)
image = image.resize((resize_size, resize_size), Image.LANCZOS)
# 2. convert to a numpy array and adjust dtype and range
# PIL image to numpy array (H,W,C) in range [0,255]
image_array = np.array(image, dtype=np.float32)
# if the image is RGBA, keep only the RGB channels
if image_array.shape[-1] == 4:
image_array = image_array[:, :, :3]
# scale to the [0,1] range (equivalent to scale=True)
image_array /= 255.0
# 3. normalize
# note: PyTorch's Normalize is applied per channel
image_array = (image_array - mean) / std
# 4. transpose (H,W,C) to (C,H,W), matching the PyTorch tensor layout
image_array = np.transpose(image_array, (2, 0, 1))
return image_array
return transform
transform = make_lvd_transform(224)
def predict_onnx_dino(image):
im_pil = safe_load_img(image)
input_name_model = session_dino3.get_inputs()[0].name
output_name_model = session_dino3.get_outputs()[0].name
return session_dino3.run([output_name_model],
{input_name_model:
np.expand_dims(transform(im_pil), axis=0).astype(np.float32)
}
)[0]
# inference code and helpers for classifying dinov3 features
def predict_dino_classify(tokens1,tokens2):
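# Assumption: the slice below drops the first 5 tokens (dinov3's CLS token
# plus 4 register tokens), keeping only the patch tokens.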
patch_tokens1 = tokens1[:, 5:, :]
patch_tokens2 = tokens2[:, 5:, :]
input_name_model =session_dino_cf.get_inputs()[0].name
output_name_model =session_dino_cf.get_outputs()[0].name
emb1 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens1})[0]
emb2 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens2})[0]
emb1_flat = emb1.flatten()
emb2_flat = emb2.flatten()
return float(np.dot(emb1_flat, emb2_flat) / (np.linalg.norm(emb1_flat) * np.linalg.norm(emb2_flat)))
def predict_dino_classify_pipeline(image,draw_result=False):
im_pil = safe_load_img(image)
if draw_result:
draw = ImageDraw.Draw(im_pil)
crops = predict_onnx_yolo(im_pil)
features = {}
for k in crops:
features.update({k:[]})
for v in crops[k]:
features[k].append({"feature":predict_onnx_dino(v['image']),"bbox":v['bbox']})
features["bottom"] = sorted(features["bottom"], key=lambda x: x["bbox"][0])
used_indices = set()
sequence = []
for target in features['bottom']:
available = [(idx, opt) for idx, opt in enumerate(features['top']) if idx not in used_indices]
if not available:
break
if len(available) == 1:
best_idx, best_opt = available[0]
else:
best_idx, best_opt = max(
available,
key=lambda item: predict_dino_classify(target['feature'], item[1]['feature'])
)
sequence.append(best_opt['bbox'])
used_indices.add(best_idx)
colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
points = []
for id,one in enumerate(sequence):
center_x = (one[0] + one[2]) / 2
center_y = (one[1] + one[3]) / 2
w = abs(one[0] - one[2])
y = abs(one[1] - one[3])
points.append((center_x+random.randint(int(-w/5),int(w/5)),
center_y+random.randint(int(-y/5),int(y/5))
))
if draw_result:
draw.rectangle(one, outline=colors[id], width=1)
text = f"{id+1}"
text_position = (center_x, center_y)
draw.text(text_position, text, fill='white')
if draw_result:
save_path_temp = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path_temp)
logger.info(f"Visualization temporarily saved at {save_path_temp}; moved to {save_path} after the run completes")
return points
logger.info(f"Inference device: {ort.get_device()}")
def use_pdl():
load_model()
def use_dfine():
load_dfine_model()
def use_multi():
load_yolo11n()
load_dinov3()
load_dino_classify()
model_for = [
{"loader":use_pdl,
"include":["session"],
"support":['paddle','pdl','nine','原神','genshin']
},
{"loader":use_dfine,
"include":["session_dfine"],
"support":['dfine','click','memo','note','‌便笺‌']
},
{"loader":use_multi,
"include":['session_yolo11n', 'session_dino3', 'session_dino_cf'],
"support":['multi','dino','click2','星穹铁道','崩铁','绝区零','zzz','hkrpg']
}
]
def get_models():
res = ["Models currently loaded and their keywords:"]
for key,value in globals().items():
if key.startswith("session") and value is not None:
for one in model_for:
if key in one['include']:
res.append(f" -{key}, keywords: {one['support']}")
return res
def get_available_models():
res = ["All available models and their keywords:"]
for one in model_for:
res.append(f" -{one['include']}, keywords: {one['support']}")
return res
def load_by(name):
for one in model_for:
if name in one['support'] or name in one['include']:
one['loader']()
return get_models()
logger.error("Unsupported name; it can be specified as ‌便笺‌, 原神, 崩铁, or 绝区零")
def unload(*names, safe_mode=True):
import gc
protected_vars = {'__name__', '__file__', '__builtins__',
'unload'}
for name in names:
if name in globals():
if safe_mode and name in protected_vars:
logger.error(f"Warning: skipping protected variable '{name}'")
continue
if not name.startswith('session'):
logger.info("The name being removed is not a model!")
var = globals()[name]
if hasattr(var, 'close'):
try:
var.close()
except:
pass
globals()[name] = None
logger.info(f"Released variable: {name}")
collected = gc.collect()
logger.info(f"The garbage collector cleaned up {collected} objects")
return get_models()
if int(os.environ.get("use_pdl",1)):
load_model()
use_pdl()
if int(os.environ.get("use_dfine",1)):
load_dfine_model()
use_dfine()
if int(os.environ.get("use_multi",1)):
use_multi()
if __name__ == "__main__":
# use resnet18.onnx
# load_model("resnet18.onnx")
@@ -368,5 +771,7 @@ if __name__ == "__main__":
# use PP-HGNetV2-B4.onnx
#predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)
print(predict_onnx_dfine(r"f:\项目留档\JPEGImages\8bdee494b00d401aae3f496e76d886fc.jpg",True))
# use_multi()
# print(predict_dino_classify_pipeline("0a92e85f89b345279e74deaa9afa9e1c.jpg",True))