mirror of https://github.com/luguoyixiazi/test_nine.git
synced 2025-12-06 14:52:49 +08:00

Compare commits — 4 commits: c26e354c2f ... a6a53bb46b

- a6a53bb46b
- e98e47a0af
- 84d225dc14
- b4d35b9906

README.md (55 lines changed)
@@ -8,19 +8,28 @@

 ## Reference projects

-Model and V4 dataset: https://github.com/taisuii/ClassificationCaptchaOcr
+resnet model and V4 dataset: https://github.com/taisuii/ClassificationCaptchaOcr

 Click-detection model: https://github.com/Peterande/D-FINE

+Another click-detection model: https://github.com/facebookresearch/dinov3

+Yet another click-detection model: https://github.com/ultralytics/ultralytics

 api: https://github.com/ravizhan/geetest-v3-click-crack

 Thanks to @kissnavel for the apis for the two games I don't play

 ## How to run

 ### 1. Install dependencies (required for local use; with docker skip to [5-b](#docker))

-(Optional) a. To train the paddle model you also need paddlex and its image-classification module; for installation see the project https://github.com/PaddlePaddle/PaddleX
-(Optional) b. For d-fine training see the project https://github.com/Peterande/D-FINE
-(* Required!) c. Create a model folder in the project directory and put the model files in it; names like resnet18.onnx or PP-HGNetV2-B4.onnx work. PP-HGNetV2-B4 is the default; if you use resnet, set use_v3_model to False, because the two models' inputs and outputs differ (adapt as needed). The d-fine model goes in the same path and is used to pass the click challenge
+(Optional) a-0. To train the paddle model you also need paddlex and its image-classification module; for installation see the project [https://github.com/PaddlePaddle/PaddleX](https://github.com/PaddlePaddle/PaddleX)
+(Optional) a-1. For d-fine training see the project [https://github.com/Peterande/D-FINE](https://github.com/Peterande/D-FINE)
+(Optional) a-2. For dinov3 usage see the project [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3); the classification is based on patch tokens
+(Optional) a-3. For yolo training see the project [https://github.com/ultralytics/ultralytics](https://github.com/ultralytics/ultralytics)
+(Optional) a-4. The dinov3-based classification has an extremely crude architecture: extract the feature maps of all localized crops in a group, take the bottom row as anchors, set 1 positive sample and treat all the others as negatives, and learn which dimensions of the dinov3 features drive the similarity comparison
+(* Required!) b. Create a model folder in the project directory and put the model files in it; names like resnet18.onnx or PP-HGNetV2-B4.onnx work. PP-HGNetV2-B4 is the default; if you use resnet, set use_v3_model to False, because the two models' inputs and outputs differ (adapt as needed). The d-fine model goes in the same path and is used to pass the click challenge. Additionally, dinov3, yolo11n, and my own task head combine into one pipeline, and all of them have to go in too

 ```
 pip install -r requirements.txt
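For reference, requirement (* Required!) b implies a model folder roughly like the sketch below; the file names are the loader defaults in predict.py (atten.onnx is the dino classification head), so treat this as an illustration rather than an exhaustive list:

```
model/
├─ PP-HGNetV2-B4.onnx   nine-grid classifier (default)
├─ resnet18.onnx        nine-grid alternative (set use_v3_model=False)
├─ d-fine-n.onnx        icon-click detector
├─ yolo11n.onnx         pipeline detector
├─ dinov3-small.onnx    pipeline feature extractor
└─ atten.onnx           pipeline classification head
```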
@@ -38,7 +47,7 @@ pip install -r requirements_without_train.txt

 - For dataset details see the project referenced above; it ships a V4 dataset and V3 has no demo, so improvise. Training on V4 and testing on V3 without code changes gives a rather underwhelming accuracy
 - The main issue is that V4 and V3 sizes differ. The V4 api returns two images directly, a target image and a nine-grid image; V3 packs them together, so the target has to be cropped out, and the V3 target image has very low resolution. A cropped V4 nine-grid tile is 100 * 86 (black borders removed) while a cropped V3 tile is 112 * 112. I am not sure what transform the V4 nine-grid applies on top of V3; either way, fixing the preprocessing is all it takes

-##### b. Train PP-HGNetV2-B4 (optional)
+##### a-0. Train PP-HGNetV2-B4 (optional)

 Picked more or less at random on paddle; the dataset format is below. If you train for V3 from V4 data, adding more augmentations is advisable
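To make the tile-size mismatch above concrete, a minimal sketch (the helper name is hypothetical; the project's real cropping lives in crop_image.py, which this diff does not touch):

```
from PIL import Image

# Hypothetical illustration: bring a V4-style tile (100x86, black borders
# already removed) to the V3 tile size (112x112) so one classifier can
# consume both; the project's actual preprocessing may differ.
def normalize_tile(tile: Image.Image, size=(112, 112)) -> Image.Image:
    return tile.resize(size, Image.Resampling.BILINEAR)
```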
@@ -50,17 +59,22 @@ pip install -r requirements_without_train.txt

 └─ validation and test sets laid out the same way
 ```

-##### b. Train d-fine (optional)
+##### a-1. Train d-fine (optional)

 The dataset format is as described in d-fine; if you do not modify the source, num_classes needs +1. Plain coco format is fine. I used 320*320 and commented RandomZoomOut, RandomHorizontalFlip, and RandomIoUCrop out of the dataloader (I do all of these during dataset generation instead)

-##### c. To crop V3 nine-grid images use crop_image_v3 from crop_image.py; for V4 use crop_image; write your own cropping script
+##### a-2. Train the dinov3 + yolo pipeline (optional)

+I drew up the dataset format myself: each load uses one anchor, one positive, and several negatives, so I organized the format up front (the raw annotation is still coco). All annotations on an image are gathered into two lists, split into top and bottom along the y axis, with bottom as the anchor (it makes no real difference, but actual use never compares within the top group, so skipping that saves resources); in the top group, 1 is ✅ and the rest are ❌

+##### b. To crop V3 nine-grid images use crop_image_v3 from crop_image.py; for V4 use crop_image; write your own cropping script

 ### 3. Train the models (optional)

 - To train resnet18 run `python train.py`
 - To train PP-HGNetV2-B4 run `python train_paddle.py`
 - To train d-fine follow the original project; one model covers it all, at somewhat lower resource cost than ddddocr + similarity matching
+- The training code is extremely simple; based on the idea I describe, i.e. the dataset organization, an llm can generate it directly. For yolo training just follow its repository. dinov3 itself is not trained and is not used for localization either, because a localization head would be larger than yolo11n, which is not worth it

 ### 4-a. Convert the PP-HGNetV2-B4 and resnet models to onnx (optional)
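A minimal sketch of the a-2 sample organization described above (the names and the label convention are assumptions for illustration; the real format is the author's own):

```
# Hypothetical sketch: split one image's boxes into top/bottom along the
# y axis, then pair each bottom anchor with 1 positive and the remaining
# top boxes as negatives. `labels` linking anchors to their positives is
# an assumption, not something this diff specifies.
def make_samples(boxes, labels, img_h):
    mid = img_h / 2
    top = [i for i, b in enumerate(boxes) if (b[1] + b[3]) / 2 < mid]
    bottom = [i for i, b in enumerate(boxes) if (b[1] + b[3]) / 2 >= mid]
    samples = []
    for a in bottom:
        pos = [t for t in top if labels[t] == labels[a]]
        neg = [t for t in top if labels[t] != labels[a]]
        if pos:
            samples.append({"anchor": a, "positive": pos[0], "negatives": neg})
    return samples
```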
@@ -71,18 +85,30 @@ pip install -r requirements_without_train.txt

 - Convert as in the original project
 - Image preprocessing at inference should match training; the onnx-inference preprocessing in the d-fine repository does not match its training…

+### 4-c. Convert the composite models to onnx (optional)

+- Convert yolo as in its original project
+- For dinov3 the official export is enough; it is only used to extract features
+- The classifier export also uses the built-in pytorch/onnx functions; it is so small that nothing special is needed
+- The models are not merged into one mainly so they can be reused across future updates: the features dinov3 extracts could actually handle all of the tasks above. The localization model would be a bit larger, but a head of a few hundred KB could replace the nine-grid classifier, and if click localization only has to split top from bottom, an even smaller model might do

 ### 5-a. Start the fastapi service (a trained onnx-format model is required)

-Run `python main.py` (the paddle onnx model is used by default; edit the comments yourself to use resnet18) or `uvicorn main:app --host 0.0.0.0 --port 9645 --reload`
+Run `python main.py` (the nine-grid uses the paddle onnx model by default; edit the comments yourself to use resnet18) or `uvicorn main:app --host 0.0.0.0 --port 9645 --reload`
+Environment variables customize which models get loaded; use_pdl, use_dfine, and use_multi are currently supported, and the log level has an environment variable too, LOG_LEVEL (see the example after this hunk)

-Because of the trajectory issue, a verification may be recognized correctly yet still fail, so increasing the retry count is recommended; the trained paddle model is above 99.9% accurate
+Because of the trajectory issue, a verification may be recognized correctly yet still fail, so increasing the retry count is recommended; the trained paddle model is above 99.9% accurate.
+The dinov3 onnx is not exactly equivalent to the pth, so the head trained on pth features is slightly less precise: 1 miss in 4613 images, and a few more with the quantized dinov3. The head could of course be trained on features extracted by the quantized onnx, but that is much slower than pth, so maybe later

 ### 5-b. Start the service with docker

-The image is <span id="docker">luguoyixiazi/test_nine:25.7.2</span>
+The image is <span id="docker">luguoyixiazi/test_nine:25.7.2</span>; this build is dedicated to those who only need the nine-grid and the memo

+The image is <span id="docker">luguoyixiazi/test_nine:25.11.2</span>; this build additionally supports 🛤 and 3z

-At runtime just bind the port and set the two environment variables `use_pdl` and `use_dfine` (1 loads the model, 0 skips it; both default to on). The api endpoint is /pass_uni with required parameters gt and challenge; separate pass_nine and pass_icon endpoints exist as well, with more optional parameters
+At runtime just bind the port and set the environment variables `use_pdl`, `use_dfine`, and `use_multi` (1 loads the model, 0 skips it; all default to on). The api endpoint is /pass_uni with required parameters gt and challenge; separate pass_nine and pass_icon endpoints exist as well, with more optional parameters

 ### 6. Calling the api
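For example, launching either way (a sketch; the flags are the ones documented above and the image tag comes from 5-b):

```
# local: nine-grid model only, with debug logging
use_pdl=1 use_dfine=0 use_multi=0 LOG_LEVEL=DEBUG python main.py

# docker: bind the port and choose models via environment variables
docker run -d -p 9645:9645 -e use_pdl=1 -e use_dfine=1 -e use_multi=1 \
    luguoyixiazi/test_nine:25.11.2
```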
@@ -103,8 +129,13 @@ def game_captcha(gt: str, challenge: str):

 The endpoint for snap hutao is /pass_hutao; its return value is already aligned, so filling the api in as `(http://127.0.0.1:9645/pass_hutao?gt={0}&challenge={1})` is enough

+Auxiliary apis were added to load and unload models, list models, and change the log level

 For actual calling code see the consuming project; the examples here only show the API url and parameters

 #### --Promotion--

 Welcome to support my other projects (they pair well), meow~~~~~~~~
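As a minimal illustration of the endpoints above (a sketch using httpx, which the project already depends on; the endpoint names and parameters are the documented ones, the placeholder values are not):

```
import httpx

BASE = "http://127.0.0.1:9645"

# universal endpoint: gt and challenge are required
print(httpx.get(f"{BASE}/pass_uni", params={"gt": "...", "challenge": "..."}).json())

# auxiliary endpoints: list/load models and change the log level
print(httpx.get(f"{BASE}/list_model").json())
print(httpx.get(f"{BASE}/load_model", params={"name": "nine"}).json())
print(httpx.get(f"{BASE}/set_log_level", params={"level": "DEBUG"}).json())
```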
crack.py (96 lines changed)

@@ -1,16 +1,19 @@
 import os
 import hashlib
 import json
 import math
 import random
 import time

+import logging
 import httpx
+from os import path as PATH
+from crop_image import validate_path
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 from cryptography.hazmat.primitives import padding, serialization
 from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
-from crop_image import validate_path
-from os import path as PATH
-import os

+logger = logging.getLogger(__name__)

 class Crack:
     def __init__(self, gt=None, challenge=None):
         self.pic_path = None

@@ -39,6 +42,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
         res = self.session.get(url)
         data = json.loads(res.text[1:-1])["data"]
+        logger.debug(f"Fetched again: {data}")
         return data

     @staticmethod

@@ -358,6 +362,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         }
         resp = self.session.get("https://api.geetest.com/get.php", params=params).text
         data = json.loads(resp[22:-1])["data"]
+        logger.debug(f"c/s result: {data}")
         self.c = data["c"]
         self.s = data["s"]
         return data["c"], data["s"]

@@ -398,9 +403,79 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         ]
         tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
         rp = self.MD5(self.gt + self.challenge + self.s)
-        temp1 = '''"lang":"zh-cn","type":"fullpage","tt":"%s","light":"DIV_0","s":"c7c3e21112fe4f741921cb3e4ff9f7cb","h":"321f9af1e098233dbd03f250fd2b5e21","hh":"39bd9cad9e425c3a8f51610fd506e3b3","hi":"09eb21b3ae9542a9bc1e8b63b3d9a467","vip_order":-1,"ct":-1,"ep":{"v":"9.1.9-dbjg5z","te":false,"me":true,"ven":"Google Inc. (Intel)","ren":"ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)","fp":["scroll",0,1602,1724571628498,null],"lp":["up",386,217,1724571629854,"pointerup"],"em":{"ph":0,"cp":0,"ek":"11","wd":1,"nt":0,"si":0,"sc":0},"tm":{"a":1724571567311,"b":1724571567549,"c":1724571567562,"d":0,"e":0,"f":1724571567312,"g":1724571567312,"h":1724571567312,"i":1724571567317,"j":1724571567423,"k":1724571567330,"l":1724571567423,"m":1724571567545,"n":1724571567547,"o":1724571567569,"p":1724571568259,"q":1724571568259,"r":1724571568261,"s":1724571570378,"t":1724571570378,"u":1724571570380},"dnf":"dnf","by":0},"passtime":1600,"rp":"%s",''' % (
-            tt, rp)
-        r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
+        payload_dict = {
+            "lang": "zh-cn",
+            "type": "fullpage",
+            "tt": tt,  # use the variable
+            "light": "DIV_0",
+            "s": "c7c3e21112fe4f741921cb3e4ff9f7cb",
+            "h": "321f9af1e098233dbd03f250fd2b5e21",
+            "hh": "39bd9cad9e425c3a8f51610fd506e3b3",
+            "hi": "09eb21b3ae9542a9bc1e8b63b3d9a467",
+            "vip_order": -1,
+            "ct": -1,
+            "ep": {
+                "v": "9.2.0-guwyxh",
+                "te": False,  # JSON 'false' maps to Python 'False'
+                "me": True,  # JSON 'true' maps to Python 'True'
+                "ven": "Google Inc. (Intel)",
+                "ren": "ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)",
+                "fp": [
+                    "scroll",
+                    0,
+                    1602,
+                    1724571628498,
+                    None  # JSON 'null' maps to Python 'None'
+                ],
+                "lp": [
+                    "up",
+                    386,
+                    217,
+                    1724571629854,
+                    "pointerup"
+                ],
+                "em": {
+                    "ph": 0,
+                    "cp": 0,
+                    "ek": "11",
+                    "wd": 1,
+                    "nt": 0,
+                    "si": 0,
+                    "sc": 0
+                },
+                "tm": {
+                    "a": 1724571567311,
+                    "b": 1724571567549,
+                    "c": 1724571567562,
+                    "d": 0,
+                    "e": 0,
+                    "f": 1724571567312,
+                    "g": 1724571567312,
+                    "h": 1724571567312,
+                    "i": 1724571567317,
+                    "j": 1724571567423,
+                    "k": 1724571567330,
+                    "l": 1724571567423,
+                    "m": 1724571567545,
+                    "n": 1724571567547,
+                    "o": 1724571567569,
+                    "p": 1724571568259,
+                    "q": 1724571568259,
+                    "r": 1724571568261,
+                    "s": 1724571570378,
+                    "t": 1724571570378,
+                    "u": 1724571570380
+                },
+                "dnf": "dnf",
+                "by": 0
+            },
+            "passtime": 1600,
+            "rp": rp,  # use the variable
+            "captcha_token":"2064329542",
+            "tsfq":"xovrayel"
+        }
+        # r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
+        r = json.dumps(payload_dict, separators=(',', ':'))
         ct = self.aes_encrypt(r)
         s = [byte for byte in ct]
         w = self.encode(s)

@@ -414,6 +489,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
             "w": w
         }
         resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
+        logger.debug(f"ajax result: {resp}")
         return json.loads(resp[22:-1])["data"]

     def get_pic(self):

@@ -435,11 +511,15 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         }
         resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
         data = json.loads(resp[22:-1])["data"]
+        logger.debug(f"Picture fetch result: {data}")
         self.pic_path = data["pic"]
         pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]

         pic_data = self.session.get(pic_url).content
         pic_name = data["pic"].split("/")[-1]
         pic_type = data["pic_type"]
+        if "80/2023-12-04T12/icon" in pic_url:
+            pic_type = "icon1"
         with open(PATH.join(validate_path,pic_name),'wb+') as f:
             f.write(pic_data)
         return pic_data,pic_name,pic_type

@@ -483,4 +563,4 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
             "w": w
         }
         resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
-        return resp[1:-1]
\ No newline at end of file
+        return resp[1:-1]
main.py (147 lines changed)

@@ -4,45 +4,86 @@ import time
 import httpx
 import shutil
 import random
 import uvicorn
 import logging
+import logging.config
-from crack import Crack
-from typing import Optional, Dict, Any
-from contextlib import asynccontextmanager
-from fastapi.responses import JSONResponse
-from fastapi import FastAPI,Query, HTTPException
-from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
-from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
-
-PORT = 9645
-platform = os.name
+import logging.handlers
+LOG_DIR = "logs"
+os.makedirs(LOG_DIR, exist_ok=True)
+LOG_FILENAME = os.path.join(LOG_DIR, "app.log")
+log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
 # --- logging configuration dict ---
 LOGGING_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,
     "formatters": {
         # define a formatter named "default"
         "default": {
-            "()": f"uvicorn.{'_logging' if platform== 'nt' else 'logging'}.DefaultFormatter",
-            "fmt": "%(levelprefix)s %(asctime)s | %(message)s",
+            "()": "logging.Formatter",  # use the standard Formatter
+            "fmt": "[%(levelname)s - %(filename)s:%(lineno)d | %(funcName)s] - %(asctime)s - %(message)s",
             "datefmt": "%Y-%m-%d %H:%M:%S",
         },
     },
     "handlers": {
-        "default": {
-            "formatter": "default",
+        # console output handler
+        "console": {
             "class": "logging.StreamHandler",
+            "formatter": "default",  # use the "default" format defined above
             "stream": "ext://sys.stderr",
         },
+        # file output and rotation handler
+        "file_rotating": {
+            "class": "logging.handlers.TimedRotatingFileHandler",
+            "formatter": "default",  # also uses the "default" format defined above
+            "filename": LOG_FILENAME,  # log file path
+            "when": "D",  # rotate by day ('D' for Day)
+            "interval": 1,  # rotate once a day
+            "backupCount": 2,  # keep 2 old log files (3 days including the current one)
+            "encoding": "utf-8",
+        },
     },
     "loggers": {
-        # set the root logger level to INFO
-        "": {"handlers": ["default"], "level": "INFO"},
-        "uvicorn.error": {"level": "INFO"},
-        "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
+        # root logger
+        "": {
+            # send logs to both the console and file_rotating handlers
+            "handlers": ["console", "file_rotating"],
+            "level": log_level,
+        },
+        # configure uvicorn's loggers so they use our settings too
+        "uvicorn": {
+            "handlers": ["console", "file_rotating"],
+            "level": "WARNING",
+            "propagate": False,  # stop uvicorn logs propagating to the root logger and being logged twice
+        },
+        "uvicorn.error": {
+            "level": "WARNING",
+            "propagate": True,  # uvicorn.error should propagate so the root logger can catch it
+        },
+        "uvicorn.access": {
+            "handlers": ["console", "file_rotating"],
+            "level": log_level,
+            "propagate": False,
+        },
     },
 }
+logging.config.dictConfig(LOGGING_CONFIG)
+logger = logging.getLogger(__name__)
+from crack import Crack
+from typing import Optional, Dict, Any
+from contextlib import asynccontextmanager
+from fastapi.responses import JSONResponse
+from fastapi import FastAPI,Query, HTTPException
+from predict import (predict_onnx,
+                     predict_onnx_pdl,
+                     predict_onnx_dfine,
+                     predict_dino_classify_pipeline,
+                     load_by,
+                     unload,
+                     get_models,
+                     get_available_models)
+from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path

+PORT = 9645

 def get_available_hosts() -> set[str]:
     """Return every usable IPv4 address of this machine."""
     import socket

@@ -70,7 +111,7 @@ async def lifespan(app: FastAPI):
     port = server.config.port if server else PORT
     if host == "0.0.0.0":
         available_hosts = get_available_hosts()
-        logger.info(f"Service addresses (pick as needed; inside docker use the host machine's host:{port}):")
+        logger.info(f"Service addresses (pick as needed; inside docker use the host machine's host:{port}; when run via uvicorn, go by the command line):")
         for h in sorted(list(available_hosts)):
             logger.info(f" - http://{h}:{port}")
     else:

@@ -78,6 +119,7 @@ async def lifespan(app: FastAPI):
     logger.info("Available service routes:")
     for route in app.routes:
         logger.info(f" -{route.methods} {route.path}")
+    logger.info("See /docs for detailed api usage")
     logger.info("="*50)

     yield

@@ -89,9 +131,9 @@ app = FastAPI(title="Geetest V3 icon click + nine-grid", lifespan=lifespan)

 def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
     """Fetch the challenge information."""
-    logging.info(f"Start fetching:\ngt:{gt}\nchallenge:{challenge}")
+    logger.info(f"Start fetching:\ngt:{gt}\nchallenge:{challenge}")
     crack = Crack(gt, challenge)
-    crack.gettype()
+    logger.debug(f"First fetch: {crack.gettype()}")
     crack.get_c_s()
     time.sleep(random.uniform(0.4,0.6))
     crack.ajax()

@@ -114,7 +156,12 @@ def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -
 def do_pass_icon(pic:Any, draw_result: bool) -> list[str]:
     """Handle the icon-click captcha; return a list of coordinate points."""
     result_list = predict_onnx_dfine(pic,draw_result)
-    print(result_list)
+    logger.debug(result_list)
     return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]

+def do_pass_icon0(pic:Any, draw_result: bool) -> list[str]:
+    """Handle the icon-click captcha; return a list of coordinate points."""
+    result_list = predict_dino_classify_pipeline(pic,draw_result)
+    return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]

 def save_image_for_train(pic_name,pic_type,passed):

@@ -141,14 +188,16 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
         crack, pic_content, pic_name, pic_type = prepare(gt, challenge)

         # 2. recognition
+        logger.debug(f"Received picture type {pic_type}")
         if pic_type == "nine":
             point_list = do_pass_nine(
                 pic_content,
                 use_v3_model=kwargs.get("use_v3_model", True),
                 point=kwargs.get("point",None)
             )
-        elif pic_type == "icon":
+        elif pic_type == 'icon':  # dino
+            point_list = do_pass_icon0(pic_content, save_result)
+        elif pic_type == "icon1":  # d-fine
             point_list = do_pass_icon(pic_content, save_result)
         else:
             raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")

@@ -169,14 +218,14 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
             os.remove(os.path.join(validate_path,pic_name))

         total_time = time.monotonic() - start_time
-        logging.info(
+        logger.info(
             f"Request finished in {total_time:.2f}s (waited {wait_time:.2f}s). "
             f"Result: {result}"
         )
         return JSONResponse(content=result)

     except Exception as e:
-        logging.error(f"Service error: {e}", exc_info=True)
+        logger.error(f"Service error: {e}", exc_info=True)
         return JSONResponse(
             status_code=500,
             content={"error": "An internal server error occurred.", "detail": str(e)}

@@ -227,12 +276,50 @@ def pass_hutao(gt: str = Query(...),
         return JSONResponse(content=rebuild_content, status_code=original_status_code)

     except Exception as e:
-        logging.error(f"Route-rewrite error: {e}", exc_info=True)
+        logger.error(f"Route-rewrite error: {e}", exc_info=True)
         return JSONResponse(
             status_code=500,
             content={"error": "An internal server error occurred.", "detail": str(e)}
         )

+@app.get("/list_model")
+def list_model():
+    return JSONResponse(content = get_models())
+@app.get("/list_all_model")
+def list_all_model():
+    return JSONResponse(content = get_available_models())
+@app.get("/load_model")
+def load_model(name: str = Query(...)):
+    return JSONResponse(content = load_by(name))
+@app.get("/unload_model")
+def unload_model(name: str = Query(...)):
+    return JSONResponse(content = unload(name))

+@app.get("/set_log_level")
+def set_log_level(level: str = Query(...)):
+    """
+    Dynamically change the level of all major loggers while the service runs.
+    For example: /set_log_level?level=DEBUG
+    """
+    # convert the level string to the logging module's integer value
+    level_str = str(level).upper()
+    numeric_level = getattr(logging, level_str, None)
+    if not isinstance(numeric_level, int):
+        raise HTTPException(status_code=400, detail=f"Invalid log level: {level}")

+    # fetch and adjust every key logger from the configuration
+    # 1. the root logger
+    logging.getLogger().setLevel(numeric_level)
+    # 2. the uvicorn loggers
+    logging.getLogger("uvicorn").setLevel(numeric_level)
+    logging.getLogger("uvicorn.error").setLevel(numeric_level)
+    logging.getLogger("uvicorn.access").setLevel(numeric_level)

+    # emit a high-level log line so the change is sure to be visible
+    logger.warning(f"All loggers dynamically switched to level: {level_str}")

+    return JSONResponse(content = f"Log level successfully set to {level_str}")

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app,port=PORT)
+    uvicorn.run(app,port=PORT,log_config=LOGGING_CONFIG)
predict.py (469 lines changed)

@@ -1,12 +1,40 @@
 import os
-import numpy as np
-from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
 import time
-from PIL import Image, ImageDraw
+import base64
+import random
+import logging
+import numpy as np
+from io import BytesIO
 import onnxruntime as ort

+from PIL import Image, ImageDraw
+from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path,save_path
+logger = logging.getLogger(__name__)
+def safe_load_img(image):
+    im_pil = None
+    try:
+        if isinstance(image, Image.Image):
+            im_pil = image
+        elif isinstance(image, str):
+            try:
+                im_pil = Image.open(image)
+            except (IOError, FileNotFoundError):
+                # not a readable path: treat the string as base64 data
+                if ',' in image:
+                    image = image.split(',')[-1]
+                padding = len(image) % 4
+                if padding > 0:
+                    image += '=' * (4 - padding)
+                img_bytes = base64.b64decode(image)
+                im_pil = Image.open(BytesIO(img_bytes))
+        elif isinstance(image, bytes):
+            im_pil = bytes_to_pil(image)
+        elif isinstance(image, np.ndarray):
+            im_pil = Image.fromarray(image)
+        else:
+            raise ValueError(f"Unsupported input type: {type(image)}")
+        return im_pil.convert("RGB")
+    except Exception as e:
+        raise ValueError(f"Could not load or parse the image, error: {e}")

 def predict(icon_image, bg_image):
     import torch
     from train import MyResNet18, data_transform
@@ -34,7 +62,7 @@ def predict(icon_image, bg_image):
     model = MyResNet18(num_classes=91)  # the class count must match training
     model.load_state_dict(torch.load(model_path))
     model.eval()
-    print("Model loaded, took:", time.time() - start)
+    logger.info(f"Model loaded, took: {time.time() - start}")
     start = time.time()

     target_images = torch.stack(target_images, dim=0)

@@ -52,14 +80,14 @@ def predict(icon_image, bg_image):
     )
     scores.append(similarity.cpu().item())
     # confidence for each tile, left to right, top to bottom
-    print(scores)
+    logger.info(scores)
     # sort the array, keeping indices
     indexed_arr = list(enumerate(scores))
     sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
     # take the three largest values and their indices
     largest_three = sorted_arr[:3]
-    print(largest_three)
-    print("Recognition finished, took:", time.time() - start)
+    logger.info(largest_three)
+    logger.info(f"Recognition finished, took: {time.time() - start}")

 def load_model(name='PP-HGNetV2-B4.onnx'):
     # load the onnx model

@@ -69,7 +97,7 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
     model_path = os.path.join(current_dir, 'model', name)
     session = ort.InferenceSession(model_path)
     input_name = session.get_inputs()[0].name
-    print(f"Loaded model {name}, took: {time.time() - start}")
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")

 def load_dfine_model(name='d-fine-n.onnx'):
     # load the onnx model

@@ -78,8 +106,31 @@ def load_dfine_model(name='d-fine-n.onnx'):
     current_dir = os.path.dirname(os.path.abspath(__file__))
     model_path = os.path.join(current_dir, 'model', name)
     session_dfine = ort.InferenceSession(model_path)
-    print(f"Loaded model {name}, took: {time.time() - start}")
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")

+def load_yolo11n(name='yolo11n.onnx'):
+    global session_yolo11n
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_yolo11n = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")

+def load_dinov3(name='dinov3-small.onnx'):
+    global session_dino3
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_dino3 = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")

+def load_dino_classify(name='atten.onnx'):
+    global session_dino_cf
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_dino_cf = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")

 def predict_onnx(icon_image, bg_image, point = None):
     import cv2

@@ -137,8 +188,7 @@ def predict_onnx(icon_image, bg_image, point = None):
         else:
             similarity = cosine_similarity(target_output, out_put)
         scores.append(similarity)
-    # confidence for each tile, left to right, top to bottom
-    # print(scores)
+    logger.debug(f"Confidence for each tile, left to right, top to bottom:\n{scores}")
     # sort the array, keeping indices
     indexed_arr = list(enumerate(scores))
     sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)

@@ -149,7 +199,7 @@ def predict_onnx(icon_image, bg_image, point = None):
     # decide based on the scores
     else:
         answer = [one[0] for one in sorted_arr if one[1] > point]
-    print(f"Recognition finished {answer}, took: {time.time() - start}")
+    logger.info(f"Recognition finished {answer}, took: {time.time() - start}")
     #draw_points_on_image(bg_image, answer)
     return answer
@@ -202,40 +252,79 @@ def predict_onnx_pdl(images_path):
     if len(answer) == 0:
         all_sort =[np.argsort(one) for one in outputs]
         answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
-    print(f"Recognition finished {answer}, took: {time.time() - start}")
+    logger.info(f"Recognition finished {answer}, took: {time.time() - start}")
     with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
         bg_image = f.read()
-    # draw_points_on_image(bg_image, answer)
+    #draw_points_on_image(bg_image, answer)
     return answer


 # d-fine inference code and helpers
 def calculate_iou(boxA, boxB):
     """
     Compute the intersection over union (IoU) of two bounding boxes with NumPy.
     """
     # coordinates of the intersection rectangle
     xA = np.maximum(boxA[0], boxB[0])
     yA = np.maximum(boxA[1], boxB[1])
     xB = np.minimum(boxA[2], boxB[2])
     yB = np.minimum(boxA[3], boxB[3])

     # area of the intersection
     intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)

     # areas of both boxes
     boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
     boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])

     # area of the union
     union_area = float(boxA_area + boxB_area - intersection_area)

     # compute the IoU
     if union_area == 0:
         return 0.0

     iou = intersection_area / union_area
     return iou

 def non_maximum_suppression(detections, iou_threshold=0.35):
     """
     Perform non-maximum suppression (NMS) on detection results.

     Arguments:
     detections -- a list whose elements are dicts containing 'box' and 'score'.
                   e.g.: [{'box': [x1, y1, x2, y2], 'score': 0.9, ...}, ...]
     iou_threshold -- a float, the IoU threshold above which boxes count as overlapping.

     Returns:
     final_detections -- the detections kept after NMS.
     """
     # 1. nothing to do for empty input
     if not detections:
         return []

     # 2. sort boxes by confidence (score), highest first
     #    a lambda selects the sort key
     detections.sort(key=lambda x: x['score'], reverse=True)

     final_detections = []

     # 3. loop until no detections remain
     while detections:
         # 4. move the current highest-scoring detection (the first one)
         #    into the final list and remove it from the working list
         best_detection = detections.pop(0)
         final_detections.append(best_detection)

         # 5. compute the IoU of that best box against all remaining boxes
         #    and keep only those whose IoU is below the threshold
         detections_to_keep = []
         for det in detections:
             # assumes NMS only runs within the same class
             iou = calculate_iou(best_detection['box'], det['box'])
             if iou < iou_threshold:
                 detections_to_keep.append(det)

         # continue the next iteration with the filtered list
         detections = detections_to_keep

     return final_detections

@@ -246,10 +335,7 @@ def predict_onnx_dfine(image,draw_result=False):
     image_input_name = input_nodes[0].name
     size_input_name = input_nodes[1].name
     output_names = [node.name for node in output_nodes]
-    if isinstance(image,bytes):
-        im_pil = bytes_to_pil(image)
-    else:
-        im_pil = Image.open(image).convert("RGB")
+    im_pil = safe_load_img(image)
     w, h = im_pil.size
     orig_size_np = np.array([[w, h]], dtype=np.int64)
     im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
@@ -340,18 +426,335 @@ def predict_onnx_dfine(image,draw_result=False):
             text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
             draw.text(text_position, text, fill=color)
     if draw_result:
-        save_path = os.path.join(validate_path,"icon_result.jpg")
-        im_pil.save(save_path)
-        print(f"Visualization saved to {save_path}")
-    print(f"Center points in tile order: {points}")
+        save_path_temp = os.path.join(validate_path,"icon_result.jpg")
+        im_pil.save(save_path_temp)
+        logger.info(f"Visualization temporarily saved to {save_path_temp}; it is moved to {save_path} when the run completes")
+    logger.info(f"Center points in tile order: {points}")
     return points


-print(f"Inference device: {ort.get_device()}")
+# yolo inference code and helpers
+def predict_onnx_yolo(image):
+    def filter_Detections(results, thresh = 0.5):
+        results = results[0]
+        results = results.transpose()
+        # if model is trained on 1 class only
+        if len(results[0]) == 5:
+            # filter out the detections with confidence > thresh
+            considerable_detections = [detection for detection in results if detection[4] > thresh]
+            considerable_detections = np.array(considerable_detections)
+            return considerable_detections

+        # if model is trained on multiple classes
+        else:
+            A = []
+            for detection in results:
+                class_id = detection[4:].argmax()
+                confidence_score = detection[4:].max()
+                new_detection = np.append(detection[:4],[class_id,confidence_score])
+                A.append(new_detection)

+            A = np.array(A)

+            # filter out the detections with confidence > thresh
+            considerable_detections = [detection for detection in A if detection[-1] > thresh]
+            considerable_detections = np.array(considerable_detections)

+            return considerable_detections
+    def NMS(boxes, conf_scores, iou_thresh = 0.55):

+        # boxes [[x1,y1, x2,y2], [x1,y1, x2,y2], ...]

+        x1 = boxes[:,0]
+        y1 = boxes[:,1]
+        x2 = boxes[:,2]
+        y2 = boxes[:,3]

+        areas = (x2-x1)*(y2-y1)

+        order = conf_scores.argsort()

+        keep = []
+        keep_confidences = []

+        while len(order) > 0:
+            idx = order[-1]
+            A = boxes[idx]
+            conf = conf_scores[idx]

+            order = order[:-1]

+            xx1 = np.take(x1, indices= order)
+            yy1 = np.take(y1, indices= order)
+            xx2 = np.take(x2, indices= order)
+            yy2 = np.take(y2, indices= order)

+            keep.append(A)
+            keep_confidences.append(conf)

+            # iou = inter/union

+            xx1 = np.maximum(x1[idx], xx1)
+            yy1 = np.maximum(y1[idx], yy1)
+            xx2 = np.minimum(x2[idx], xx2)
+            yy2 = np.minimum(y2[idx], yy2)

+            w = np.maximum(xx2-xx1, 0)
+            h = np.maximum(yy2-yy1, 0)

+            intersection = w*h

+            # union = areaA + other_areas - intersection
+            other_areas = np.take(areas, indices= order)
+            union = areas[idx] + other_areas - intersection

+            iou = intersection/union

+            booleans = iou < iou_thresh

+            order = order[booleans]

+            # order = [2,0,1] booleans = [True, False, True]
+            # order = [2,1]

+        return keep, keep_confidences
+    def rescale_back(results,img_w,img_h,imgsz=384):
+        cx, cy, w, h, class_id, confidence = results[:,0], results[:,1], results[:,2], results[:,3], results[:, 4], results[:,-1]
+        cx = cx/imgsz * img_w
+        cy = cy/imgsz * img_h
+        w = w/imgsz * img_w
+        h = h/imgsz * img_h
+        x1 = cx - w/2
+        y1 = cy - h/2
+        x2 = cx + w/2
+        y2 = cy + h/2

+        boxes = np.column_stack((x1, y1, x2, y2, class_id))
+        keep, keep_confidences = NMS(boxes,confidence)
+        return keep, keep_confidences
+    im_pil = safe_load_img(image)
+    im_resized = im_pil.resize((384, 384), Image.Resampling.BILINEAR)
+    im_data = np.array(im_resized, dtype=np.float32) / 255.0
+    im_data = im_data.transpose(2, 0, 1)
+    im_data = np.expand_dims(im_data, axis=0)
+    res = session_yolo11n.run(None,{"images":im_data})
+    results = filter_Detections(res)
+    rescaled_results, confidences = rescale_back(results,im_pil.size[0],im_pil.size[1])
+    images = {"top":[],"bottom":[]}
+    for r, conf in zip(rescaled_results, confidences):
+        x1,y1,x2,y2, cls_id = r
+        cls_id = int(cls_id)
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+        cropped_image = im_pil.crop((x1, y1, x2, y2))
+        if cls_id == 0:
+            images['top'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
+        else:
+            images['bottom'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
+    return images

+# dinov3 inference code and helpers
+def make_lvd_transform(resize_size: int = 224):
+    """
+    Return an image-preprocessing function equivalent to the PyTorch version
+    """
+    # normalization parameters
+    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

+    def transform(image) -> np.ndarray:
+        """
+        Image preprocessing transform

+        Args:
+            image: PIL Image or numpy array (H,W,C) in the range [0,255]

+        Returns:
+            numpy array (C,H,W), normalized float32
+        """
+        # make sure the input is a PIL image
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image.astype('uint8'))

+        # 1. resize (LANCZOS anti-aliasing, matching antialias=True)
+        image = image.resize((resize_size, resize_size), Image.LANCZOS)

+        # 2. convert to a numpy array and adjust dtype and range
+        #    the PIL image becomes a numpy array (H,W,C) in [0,255]
+        image_array = np.array(image, dtype=np.float32)

+        # keep only the RGB channels if the image is RGBA
+        if image_array.shape[-1] == 4:
+            image_array = image_array[:, :, :3]

+        # scale into [0,1] (matching scale=True)
+        image_array /= 255.0

+        # 3. normalize
+        #    note: PyTorch's Normalize works per channel
+        image_array = (image_array - mean) / std

+        # 4. transpose (H,W,C) to (C,H,W), matching the PyTorch tensor layout
+        image_array = np.transpose(image_array, (2, 0, 1))

+        return image_array

+    return transform
+transform = make_lvd_transform(224)

+def predict_onnx_dino(image):
+    im_pil = safe_load_img(image)
+    input_name_model = session_dino3.get_inputs()[0].name
+    output_name_model = session_dino3.get_outputs()[0].name
+    return session_dino3.run([output_name_model],
+                             {input_name_model:
+                              np.expand_dims(transform(im_pil), axis=0).astype(np.float32)
+                             }
+                             )[0]

+# inference code and helpers for classifying dinov3 features
+def predict_dino_classify(tokens1,tokens2):
+    patch_tokens1 = tokens1[:, 5:, :]
+    patch_tokens2 = tokens2[:, 5:, :]
+    input_name_model = session_dino_cf.get_inputs()[0].name
+    output_name_model = session_dino_cf.get_outputs()[0].name
+    emb1 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens1})[0]
+    emb2 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens2})[0]
+    emb1_flat = emb1.flatten()
+    emb2_flat = emb2.flatten()
+    return float(np.dot(emb1_flat, emb2_flat) / (np.linalg.norm(emb1_flat) * np.linalg.norm(emb2_flat)))

+def predict_dino_classify_pipeline(image,draw_result=False):
+    im_pil = safe_load_img(image)
+    if draw_result:
+        draw = ImageDraw.Draw(im_pil)
+    crops = predict_onnx_yolo(im_pil)
+    features = {}
+    for k in crops:
+        features.update({k:[]})
+        for v in crops[k]:
+            features[k].append({"feature":predict_onnx_dino(v['image']),"bbox":v['bbox']})
+    features["bottom"] = sorted(features["bottom"], key=lambda x: x["bbox"][0])
+    used_indices = set()
+    sequence = []

+    for target in features['bottom']:
+        available = [(idx, opt) for idx, opt in enumerate(features['top']) if idx not in used_indices]

+        if not available:
+            break

+        if len(available) == 1:
+            best_idx, best_opt = available[0]
+        else:
+            best_idx, best_opt = max(
+                available,
+                key=lambda item: predict_dino_classify(target['feature'], item[1]['feature'])
+            )

+        sequence.append(best_opt['bbox'])
+        used_indices.add(best_idx)
+    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
+    points = []
+    for id,one in enumerate(sequence):
+        center_x = (one[0] + one[2]) / 2
+        center_y = (one[1] + one[3]) / 2
+        w = abs(one[0] - one[2])
+        h = abs(one[1] - one[3])
+        points.append((center_x+random.randint(int(-w/5),int(w/5)),
+                       center_y+random.randint(int(-h/5),int(h/5))
+                       ))
+        if draw_result:
+            draw.rectangle(one, outline=colors[id], width=1)
+            text = f"{id+1}"
+            text_position = (center_x, center_y)
+            draw.text(text_position, text, fill='white')
+    if draw_result:
+        save_path_temp = os.path.join(validate_path,"icon_result.jpg")
+        im_pil.save(save_path_temp)
+        logger.info(f"Visualization temporarily saved to {save_path_temp}; it is moved to {save_path} when the run completes")
+    return points


+logger.info(f"Inference device: {ort.get_device()}")
+def use_pdl():
+    load_model()

+def use_dfine():
+    load_dfine_model()

+def use_multi():
+    load_yolo11n()
+    load_dinov3()
+    load_dino_classify()

+model_for = [
+    {"loader":use_pdl,
+     "include":["session"],
+     "support":['paddle','pdl','nine','原神','genshin']
+    },
+    {"loader":use_dfine,
+     "include":["session_dfine"],
+     "support":['dfine','click','memo','note','便笺']
+    },
+    {"loader":use_multi,
+     "include":['session_yolo11n', 'session_dino3', 'session_dino_cf'],
+     "support":['multi','dino','click2','星穹铁道','崩铁','绝区零','zzz','hkrpg']
+    }
+]

+def get_models():
+    res = ["Currently loaded models and their keywords:"]
+    for key,value in globals().items():
+        if key.startswith("session") and value is not None:
+            for one in model_for:
+                if key in one['include']:
+                    res.append(f" -{key}, keywords: {one['support']}")
+    return res
+def get_available_models():
+    res = ["All available models and their keywords:"]
+    for one in model_for:
+        res.append(f" -{one['include']}, keywords: {one['support']}")
+    return res

+def load_by(name):
+    for one in model_for:
+        if name in one['support'] or name in one['include']:
+            one['loader']()
+            return get_models()
+    logger.error("Unsupported name; keywords such as 便笺, 原神, 崩铁, 绝区零 are accepted")

+def unload(*names, safe_mode=True):
+    import gc
+    protected_vars = {'__name__', '__file__', '__builtins__',
+                      'unload'}
+    for name in names:
+        if name in globals():
+            if safe_mode and name in protected_vars:
+                logger.error(f"Warning: skipping protected variable '{name}'")
+                continue
+            if not name.startswith('session'):
+                logger.info("The variable being removed is not a model!")
+            var = globals()[name]
+            if hasattr(var, 'close'):
+                try:
+                    var.close()
+                except:
+                    pass
+            globals()[name] = None
+            logger.info(f"Released variable: {name}")
+    collected = gc.collect()
+    logger.info(f"The garbage collector cleaned up {collected} objects")
+    return get_models()

 if int(os.environ.get("use_pdl",1)):
-    load_model()
+    use_pdl()
 if int(os.environ.get("use_dfine",1)):
-    load_dfine_model()
+    use_dfine()
+if int(os.environ.get("use_multi",1)):
+    use_multi()

 if __name__ == "__main__":
     # use resnet18.onnx
     # load_model("resnet18.onnx")

@@ -368,5 +771,7 @@ if __name__ == "__main__":

     # use PP-HGNetV2-B4.onnx
     #predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
-    predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)
+    print(predict_onnx_dfine(r"f:\项目留档\JPEGImages\8bdee494b00d401aae3f496e76d886fc.jpg",True))
+    # use_multi()
+    # print(predict_dino_classify_pipeline("0a92e85f89b345279e74deaa9afa9e1c.jpg",True))