New models and adaptation for fullpage.9.2.0-guwyxh.js

Fixes #5
Add dinov3, a task head, and yolo11n to pass both kinds of captcha.
Also update the new constants from fullpage.9.2.0-guwyxh.js:
```json
{
"captcha_token": "2064329542",
"tsfq": "xovrayel"
}
```
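This pair replaces the `"captcha_token":"1198034057","du6o":"eyjf7nne"` tail used by the previous 9.1.9-dbjg5z payload; the token value, the second key name (`du6o` → `tsfq`), and its value all appear to rotate with each obfuscated fullpage release. A minimal sketch of the swap as it lands in the payload built in crack.py below:

```python
# Rotating tail of the fullpage payload (values taken from this commit).
payload_tail_old = {"captcha_token": "1198034057", "du6o": "eyjf7nne"}  # 9.1.9-dbjg5z
payload_tail_new = {"captcha_token": "2064329542", "tsfq": "xovrayel"}  # 9.2.0-guwyxh
```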
Commit b4d35b9906 (parent c26e354c2f) by luguoyixiazi, 2025-11-04 16:40:02 +08:00, committed via GitHub.
3 changed files with 642 additions and 70 deletions

crack.py

@@ -1,16 +1,19 @@
import os
import hashlib
import json
import math
import random
import time
import logging
import httpx
from os import path as PATH
from crop_image import validate_path
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
logger = logging.getLogger(__name__)
class Crack:
def __init__(self, gt=None, challenge=None):
self.pic_path = None
@@ -39,6 +42,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
res = self.session.get(url)
data = json.loads(res.text[1:-1])["data"]
logger.debug(f"再次获得{data}")
return data
@staticmethod
@@ -358,6 +362,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
}
resp = self.session.get("https://api.geetest.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
logger.debug(f"获取cs结果{data}")
self.c = data["c"]
self.s = data["s"]
return data["c"], data["s"]
@@ -398,9 +403,79 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
]
tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
rp = self.MD5(self.gt + self.challenge + self.s)
# The 9.1.9-dbjg5z build assembled this payload from a hard-coded template string ending in "captcha_token":"1198034057","du6o":"eyjf7nne".
payload_dict = {
"lang": "zh-cn",
"type": "fullpage",
"tt": tt, # 使用变量
"light": "DIV_0",
"s": "c7c3e21112fe4f741921cb3e4ff9f7cb",
"h": "321f9af1e098233dbd03f250fd2b5e21",
"hh": "39bd9cad9e425c3a8f51610fd506e3b3",
"hi": "09eb21b3ae9542a9bc1e8b63b3d9a467",
"vip_order": -1,
"ct": -1,
"ep": {
"v": "9.2.0-guwyxh",
"te": False, # JSON 'false' 对应 Python 'False'
"me": True, # JSON 'true' 对应 Python 'True'
"ven": "Google Inc. (Intel)",
"ren": "ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)",
"fp": [
"scroll",
0,
1602,
1724571628498,
None # JSON 'null' maps to Python None
],
"lp": [
"up",
386,
217,
1724571629854,
"pointerup"
],
"em": {
"ph": 0,
"cp": 0,
"ek": "11",
"wd": 1,
"nt": 0,
"si": 0,
"sc": 0
},
"tm": {
"a": 1724571567311,
"b": 1724571567549,
"c": 1724571567562,
"d": 0,
"e": 0,
"f": 1724571567312,
"g": 1724571567312,
"h": 1724571567312,
"i": 1724571567317,
"j": 1724571567423,
"k": 1724571567330,
"l": 1724571567423,
"m": 1724571567545,
"n": 1724571567547,
"o": 1724571567569,
"p": 1724571568259,
"q": 1724571568259,
"r": 1724571568261,
"s": 1724571570378,
"t": 1724571570378,
"u": 1724571570380
},
"dnf": "dnf",
"by": 0
},
"passtime": 1600,
"rp": rp, # 使用变量
"captcha_token":"2064329542",
"tsfq":"xovrayel"
}
# r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
r = json.dumps(payload_dict, separators=(',', ':')) # compact separators, matching the browser's JSON.stringify output
ct = self.aes_encrypt(r)
s = [byte for byte in ct]
w = self.encode(s)
@@ -414,6 +489,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
"w": w
}
resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
logger.debug(f"ajax结果:{resp}")
return json.loads(resp[22:-1])["data"]
def get_pic(self):
@@ -435,11 +511,15 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
}
resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
logger.debug(f"获取图片结果{data}")
self.pic_path = data["pic"]
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
pic_data = self.session.get(pic_url).content
pic_name = data["pic"].split("/")[-1]
pic_type = data["pic_type"]
if "80/2023-12-04T12/icon" in pic_url:
pic_type = "icon1"
with open(PATH.join(validate_path,pic_name),'wb+') as f:
f.write(pic_data)
return pic_data,pic_name,pic_type
@@ -483,4 +563,4 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
"w": w
}
resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
return resp[1:-1]

main.py

@@ -4,45 +4,86 @@ import time
import httpx
import shutil
import random
import uvicorn
import logging
import logging.config
import logging.handlers
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILENAME = os.path.join(LOG_DIR, "app.log")
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
# --- logging configuration dict ---
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
# a formatter named "default"
"default": {
"()": f"uvicorn.{'_logging' if platform== 'nt' else 'logging'}.DefaultFormatter",
"fmt": "%(levelprefix)s %(asctime)s | %(message)s",
"()": "logging.Formatter", # 使用标准的 Formatter
"fmt": "[%(levelname)s - %(filename)s:%(lineno)d | %(funcName)s] - %(asctime)s - %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
},
"handlers": {
"default": {
"formatter": "default",
# console output handler
"console": {
"class": "logging.StreamHandler",
"formatter": "default", # 使用上面定义的 default 格式
"stream": "ext://sys.stderr",
},
# rotating file handler
"file_rotating": {
"class": "logging.handlers.TimedRotatingFileHandler",
"formatter": "default", # 也使用上面定义的 default 格式
"filename": LOG_FILENAME, # 日志文件路径
"when": "D", # 按天轮转 ('D' for Day)
"interval": 1, # 每天轮转一次
"backupCount": 2, # 保留2个旧的日志文件 (加上当前文件总共覆盖3天)
"encoding": "utf-8",
},
},
"loggers": {
# root logger
"": {
# send records to both the console and file_rotating handlers
"handlers": ["console", "file_rotating"],
"level": log_level,
},
# configure uvicorn's loggers so they also use our handlers
"uvicorn": {
"handlers": ["console", "file_rotating"],
"level": "WARNING",
"propagate": False, # 阻止 uvicorn 日志向上传播到根 logger避免重复记录
},
"uvicorn.error": {
"level": "WARNING",
"propagate": True, # uvicorn.error 应该传播,以便根记录器可以捕获它
},
"uvicorn.access": {
"handlers": ["console", "file_rotating"],
"level": log_level,
"propagate": False,
},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI,Query, HTTPException
from predict import (predict_onnx,
predict_onnx_pdl,
predict_onnx_dfine,
predict_dino_classify_pipeline,
load_by,
unload,
get_models,
get_available_models)
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
PORT = 9645
def get_available_hosts() -> set[str]:
"""获取本机所有可用的IPv4地址。"""
import socket
@@ -70,7 +111,7 @@ async def lifespan(app: FastAPI):
port = server.config.port if server else PORT
if host == "0.0.0.0":
available_hosts = get_available_hosts()
logger.info(f"服务地址(依需求选用docker中使用宿主机host:{port}):")
logger.info(f"服务地址(依需求选用docker中使用宿主机host:{port}若使用Uvicorn运行则基于命令):")
for h in sorted(list(available_hosts)):
logger.info(f" - http://{h}:{port}")
else:
@@ -78,6 +119,7 @@ async def lifespan(app: FastAPI):
logger.info(f"可用服务路径如下:")
for route in app.routes:
logger.info(f" -{route.methods} {route.path}")
logger.info(f"具体api使用可以查看/docs")
logger.info("="*50)
yield
@@ -89,9 +131,9 @@ app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)
def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
"""获取信息。"""
logging.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
logger.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
crack = Crack(gt, challenge)
logger.debug(f"First gettype result: {crack.gettype()}")
crack.get_c_s()
time.sleep(random.uniform(0.4,0.6))
crack.ajax()
@@ -114,7 +156,12 @@ def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -
def do_pass_icon(pic:Any, draw_result: bool) -> list[str]:
"""处理图标点选验证码,返回坐标点列表。"""
result_list = predict_onnx_dfine(pic,draw_result)
logger.debug(result_list)
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def do_pass_icon0(pic:Any, draw_result: bool) -> list[str]:
"""处理图标点选验证码,返回坐标点列表。"""
result_list = predict_dino_classify_pipeline(pic,draw_result)
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def save_image_for_train(pic_name,pic_type,passed):
@@ -141,14 +188,16 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
# 2. recognize
logger.debug(f"Received picture type: {pic_type}")
if pic_type == "nine":
point_list = do_pass_nine(
pic_content,
use_v3_model=kwargs.get("use_v3_model", True),
point=kwargs.get("point",None)
)
elif pic_type == "icon":
elif pic_type == 'icon': # dino
point_list = do_pass_icon0(pic_content, save_result)
elif pic_type == "icon1": # d-fine
point_list = do_pass_icon(pic_content, save_result)
else:
raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
@@ -169,14 +218,14 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
os.remove(os.path.join(validate_path,pic_name))
total_time = time.monotonic() - start_time
logger.info(
f"Request finished in {total_time:.2f}s (waited {wait_time:.2f}s). "
f"Result: {result}"
)
return JSONResponse(content=result)
except Exception as e:
logging.error(f"服务错误: {e}", exc_info=True)
logger.error(f"服务错误: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
@@ -227,12 +276,50 @@ def pass_hutao(gt: str = Query(...),
return JSONResponse(content=rebuild_content, status_code=original_status_code)
except Exception as e:
logging.error(f"修改路由错误: {e}", exc_info=True)
logger.error(f"修改路由错误: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
)
@app.get("/list_model")
def list_model():
return JSONResponse(content = get_models())
@app.get("/list_all_model")
def list_all_model():
return JSONResponse(content = get_available_models())
@app.get("/load_model")
def load_model(name: str = Query(...)):
return JSONResponse(content = load_by(name)) # load_by loads the model group and returns the refreshed list
@app.get("/unload_model")
def unload_model(name: str = Query(...)):
return JSONResponse(content = unload(name)) # unload frees the session and returns the refreshed list
@app.get("/set_log_level")
def set_log_level(level: str = Query(...)):
"""
Dynamically change the log level of every major logger while the service is running.
e.g.: /set_log_level?level=DEBUG
"""
# map the level string to logging's integer value
level_str = str(level).upper()
numeric_level = getattr(logging, level_str, None)
if not isinstance(numeric_level, int):
raise HTTPException(status_code=400, detail=f"无效的日志级别: {level}")
# update every key logger from the config
# 1. the root logger
logging.getLogger().setLevel(numeric_level)
# 2. the uvicorn loggers
logging.getLogger("uvicorn").setLevel(numeric_level)
logging.getLogger("uvicorn.error").setLevel(numeric_level)
logging.getLogger("uvicorn.access").setLevel(numeric_level)
# log at WARNING so the message is visible regardless of the new level
logger.warning(f"Log level of all loggers dynamically changed to: {level_str}")
return JSONResponse(content = f"Log level successfully set to {level_str}")
if __name__ == "__main__":
uvicorn.run(app, port=PORT, log_config=LOGGING_CONFIG)

predict.py

@@ -1,12 +1,40 @@
import os
import time
import base64
import random
import logging
import numpy as np
from io import BytesIO
import onnxruntime as ort
from PIL import Image, ImageDraw
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path,save_path
logger = logging.getLogger(__name__)
def safe_load_img(image):
im_pil = None
try:
if isinstance(image, Image.Image):
im_pil = image
elif isinstance(image, str):
try:
im_pil = Image.open(image)
except (IOError, FileNotFoundError):
if ',' in image:
image = image.split(',')[-1]
padding = len(image) % 4
if padding > 0:
image += '=' * (4 - padding)
img_bytes = base64.b64decode(image)
im_pil = Image.open(BytesIO(img_bytes))
elif isinstance(image, bytes):
im_pil = bytes_to_pil(image)
elif isinstance(image, np.ndarray):
im_pil = Image.fromarray(image)
else:
raise ValueError(f"不支持的输入类型: {type(image)}")
return im_pil.convert("RGB")
except Exception as e:
raise ValueError(f"无法加载或解析图像,错误: {e}")
def predict(icon_image, bg_image):
import torch
from train import MyResNet18, data_transform
@@ -34,7 +62,7 @@ def predict(icon_image, bg_image):
model = MyResNet18(num_classes=91) # the class count must match training
model.load_state_dict(torch.load(model_path))
model.eval()
print("加载模型,耗时:", time.time() - start)
logger.info("加载模型,耗时:", time.time() - start)
start = time.time()
target_images = torch.stack(target_images, dim=0)
@@ -52,14 +80,14 @@ def predict(icon_image, bg_image):
)
scores.append(similarity.cpu().item())
# confidence of each tile, left-to-right then top-to-bottom
logger.info(scores)
# sort while keeping the original indices
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# take the three largest scores and their indices
largest_three = sorted_arr[:3]
logger.info(largest_three)
logger.info(f"Recognition done, took {time.time() - start}s")
def load_model(name='PP-HGNetV2-B4.onnx'):
# load the onnx model
@@ -69,7 +97,7 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
model_path = os.path.join(current_dir, 'model', name)
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print(f"加载{name}模型,耗时:{time.time() - start}")
logger.info(f"加载{name}模型,耗时:{time.time() - start}")
def load_dfine_model(name='d-fine-n.onnx'):
# load the onnx model
@@ -78,8 +106,31 @@ def load_dfine_model(name='d-fine-n.onnx'):
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dfine = ort.InferenceSession(model_path)
print(f"加载{name}模型,耗时:{time.time() - start}")
logger.info(f"加载{name}模型,耗时:{time.time() - start}")
def load_yolo11n(name='yolo11n.onnx'):
global session_yolo11n
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_yolo11n = ort.InferenceSession(model_path)
logger.info(f"加载{name}模型,耗时:{time.time() - start}")
def load_dinov3(name='dinov3-small.onnx'):
global session_dino3
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dino3 = ort.InferenceSession(model_path)
logger.info(f"加载{name}模型,耗时:{time.time() - start}")
def load_dino_classify(name='atten.onnx'):
global session_dino_cf
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dino_cf = ort.InferenceSession(model_path)
logger.info(f"加载{name}模型,耗时:{time.time() - start}")
def predict_onnx(icon_image, bg_image, point = None):
import cv2
@@ -137,8 +188,7 @@ def predict_onnx(icon_image, bg_image, point = None):
else:
similarity = cosine_similarity(target_output, out_put)
scores.append(similarity)
logger.debug(f"Confidence of each tile, left-to-right then top-to-bottom:\n{scores}")
# sort while keeping the original indices
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
@@ -149,7 +199,7 @@ def predict_onnx(icon_image, bg_image, point = None):
# judge by score threshold
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"识别完成{answer},耗时: {time.time() - start}")
logger.info(f"识别完成{answer},耗时: {time.time() - start}")
#draw_points_on_image(bg_image, answer)
return answer
@@ -202,40 +252,79 @@ def predict_onnx_pdl(images_path):
if len(answer) == 0:
all_sort =[np.argsort(one) for one in outputs]
answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
print(f"识别完成{answer},耗时: {time.time() - start}")
logger.info(f"识别完成{answer},耗时: {time.time() - start}")
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
#draw_points_on_image(bg_image, answer)
return answer
# d-fine inference code and helpers
def calculate_iou(boxA, boxB):
"""
Compute the IoU (intersection over union) of two bounding boxes with NumPy.
"""
# coordinates of the intersection rectangle
xA = np.maximum(boxA[0], boxB[0])
yA = np.maximum(boxA[1], boxB[1])
xB = np.minimum(boxA[2], boxB[2])
yB = np.minimum(boxA[3], boxB[3])
# intersection area
intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
# areas of the two boxes
boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
# union area
union_area = float(boxA_area + boxB_area - intersection_area)
# IoU
if union_area == 0:
return 0.0
iou = intersection_area / union_area
return iou
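A quick sanity check with hypothetical boxes: two 10x10 boxes overlapping on a 5x5 patch intersect in 25, union 175, so the IoU is 25/175 ≈ 0.143:

```python
boxA = [0, 0, 10, 10]
boxB = [5, 5, 15, 15]
print(calculate_iou(boxA, boxB))  # 0.14285714285714285
```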
def non_maximum_suppression(detections, iou_threshold=0.35):
"""
Run non-maximum suppression (NMS) over the detections.
Args:
detections -- a list whose items are dicts with 'box' and 'score',
e.g. [{'box': [x1, y1, x2, y2], 'score': 0.9, ...}, ...]
iou_threshold -- IoU threshold above which two boxes count as overlapping.
Returns:
final_detections -- the detections kept after NMS.
"""
# 1. nothing to do if there are no detections
if not detections:
return []
# 2. sort the boxes by confidence score, high to low
detections.sort(key=lambda x: x['score'], reverse=True)
final_detections = []
# 3. loop until no detections remain
while detections:
# 4. pop the current highest-scoring detection and add it to the final list
best_detection = detections.pop(0)
final_detections.append(best_detection)
# 5. compute IoU between the popped best box and every remaining box; keep only those below the threshold
detections_to_keep = []
for det in detections:
# assumes NMS applies within the same class
iou = calculate_iou(best_detection['box'], det['box'])
if iou < iou_threshold:
detections_to_keep.append(det)
# continue the next round with the filtered list
detections = detections_to_keep
return final_detections
@@ -246,10 +335,7 @@ def predict_onnx_dfine(image,draw_result=False):
image_input_name = input_nodes[0].name
size_input_name = input_nodes[1].name
output_names = [node.name for node in output_nodes]
im_pil = safe_load_img(image)
w, h = im_pil.size
orig_size_np = np.array([[w, h]], dtype=np.int64)
im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
@@ -340,18 +426,335 @@ def predict_onnx_dfine(image,draw_result=False):
text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
draw.text(text_position, text, fill=color)
if draw_result:
save_path_temp = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path_temp)
logger.info(f"图片可视化结果暂时保存在{save_path_temp},运行完成后移至{save_path}")
logger.info(f"图片顺序的中心点{points}")
return points
print(f"使用推理设备: {ort.get_device()}")
# yolo inference code and helpers
def predict_onnx_yolo(image):
def filter_Detections(results, thresh = 0.5):
results = results[0]
results = results.transpose()
# if model is trained on 1 class only
if len(results[0]) == 5:
# filter out the detections with confidence > thresh
considerable_detections = [detection for detection in results if detection[4] > thresh]
considerable_detections = np.array(considerable_detections)
return considerable_detections
# if model is trained on multiple classes
else:
A = []
for detection in results:
class_id = detection[4:].argmax()
confidence_score = detection[4:].max()
new_detection = np.append(detection[:4],[class_id,confidence_score])
A.append(new_detection)
A = np.array(A)
# filter out the detections with confidence > thresh
considerable_detections = [detection for detection in A if detection[-1] > thresh]
considerable_detections = np.array(considerable_detections)
return considerable_detections
def NMS(boxes, conf_scores, iou_thresh = 0.55):
# boxes [[x1,y1, x2,y2], [x1,y1, x2,y2], ...]
x1 = boxes[:,0]
y1 = boxes[:,1]
x2 = boxes[:,2]
y2 = boxes[:,3]
areas = (x2-x1)*(y2-y1)
order = conf_scores.argsort()
keep = []
keep_confidences = []
while len(order) > 0:
idx = order[-1]
A = boxes[idx]
conf = conf_scores[idx]
order = order[:-1]
xx1 = np.take(x1, indices= order)
yy1 = np.take(y1, indices= order)
xx2 = np.take(x2, indices= order)
yy2 = np.take(y2, indices= order)
keep.append(A)
keep_confidences.append(conf)
# iou = inter/union
xx1 = np.maximum(x1[idx], xx1)
yy1 = np.maximum(y1[idx], yy1)
xx2 = np.minimum(x2[idx], xx2)
yy2 = np.minimum(y2[idx], yy2)
w = np.maximum(xx2-xx1, 0)
h = np.maximum(yy2-yy1, 0)
intersection = w*h
# union = areaA + other_areas - intesection
other_areas = np.take(areas, indices= order)
union = areas[idx] + other_areas - intersection
iou = intersection/union
boleans = iou < iou_thresh
order = order[boleans]
# order = [2,0,1] boleans = [True, False, True]
# order = [2,1]
return keep, keep_confidences
def rescale_back(results,img_w,img_h,imgsz=384):
cx, cy, w, h, class_id, confidence = results[:,0], results[:,1], results[:,2], results[:,3], results[:, 4], results[:,-1]
cx = cx/imgsz * img_w
cy = cy/imgsz * img_h
w = w/imgsz * img_w
h = h/imgsz * img_h
x1 = cx - w/2
y1 = cy - h/2
x2 = cx + w/2
y2 = cy + h/2
boxes = np.column_stack((x1, y1, x2, y2, class_id))
keep, keep_confidences = NMS(boxes,confidence)
return keep, keep_confidences
im_pil = safe_load_img(image)
im_resized = im_pil.resize((384, 384), Image.Resampling.BILINEAR)
im_data = np.array(im_resized, dtype=np.float32) / 255.0
im_data = im_data.transpose(2, 0, 1)
im_data = np.expand_dims(im_data, axis=0)
res = session_yolo11n.run(None,{"images":im_data})
results = filter_Detections(res)
rescaled_results, confidences = rescale_back(results,im_pil.size[0],im_pil.size[1])
images = {"top":[],"bottom":[]}
for r, conf in zip(rescaled_results, confidences):
x1,y1,x2,y2, cls_id = r
cls_id = int(cls_id)
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
cropped_image = im_pil.crop((x1, y1, x2, y2))
if cls_id == 0:
images['top'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
else:
images['bottom'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
return images
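For reference, the structure predict_onnx_yolo hands to the matcher below; class 0 boxes appear to form the "top" group of in-image icons, everything else the "bottom" prompt row that the pipeline iterates over:

```python
# {"top":    [{"image": <PIL crop>, "bbox": [x1, y1, x2, y2]}, ...],
#  "bottom": [{"image": <PIL crop>, "bbox": [x1, y1, x2, y2]}, ...]}
```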
# dinov3 inference code and helpers
def make_lvd_transform(resize_size: int = 224):
"""
Return an image preprocessing function equivalent to the PyTorch version.
"""
# normalization constants
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
def transform(image) -> np.ndarray:
"""
Image preprocessing transform.
Args:
image: PIL Image or numpy array (H,W,C) in [0,255]
Returns:
numpy array (C,H,W), normalized float32
"""
# make sure the input is a PIL image
if isinstance(image, np.ndarray):
image = Image.fromarray(image.astype('uint8'))
# 1. resize (LANCZOS antialiasing, matching antialias=True)
image = image.resize((resize_size, resize_size), Image.LANCZOS)
# 2. convert to a numpy array: (H,W,C), range [0,255]
image_array = np.array(image, dtype=np.float32)
# if RGBA, keep only the RGB channels
if image_array.shape[-1] == 4:
image_array = image_array[:, :, :3]
# scale to [0,1] (matching scale=True)
image_array /= 255.0
# 3. normalize per channel, like PyTorch's Normalize
image_array = (image_array - mean) / std
# 4. (H,W,C) -> (C,H,W), matching the PyTorch tensor layout
image_array = np.transpose(image_array, (2, 0, 1))
return image_array
return transform
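A quick shape check of the returned transform (a standalone sketch; blank image as hypothetical input):

```python
from PIL import Image

t = make_lvd_transform(224)
arr = t(Image.new("RGB", (333, 333)))
print(arr.shape, arr.dtype)  # (3, 224, 224) float32
```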
transform = make_lvd_transform(224)
def predict_onnx_dino(image):
im_pil = safe_load_img(image)
input_name_model = session_dino3.get_inputs()[0].name
output_name_model = session_dino3.get_outputs()[0].name
return session_dino3.run([output_name_model],
{input_name_model:
np.expand_dims(transform(im_pil), axis=0).astype(np.float32)
}
)[0]
# classification head over dinov3 features
def predict_dino_classify(tokens1,tokens2):
patch_tokens1 = tokens1[:, 5:, :] # drop the first 5 tokens (CLS + registers), keep the patch tokens
patch_tokens2 = tokens2[:, 5:, :]
input_name_model =session_dino_cf.get_inputs()[0].name
output_name_model =session_dino_cf.get_outputs()[0].name
emb1 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens1})[0]
emb2 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens2})[0]
emb1_flat = emb1.flatten()
emb2_flat = emb2.flatten()
return float(np.dot(emb1_flat, emb2_flat) / (np.linalg.norm(emb1_flat) * np.linalg.norm(emb2_flat)))
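The pair is scored by cosine similarity of the flattened head embeddings, so the result lies in [-1, 1]. A tiny numeric check of just the scoring step, with made-up vectors:

```python
import numpy as np

def cosine(a, b):
    # same scoring as the tail of predict_dino_classify
    a, b = a.flatten(), b.flatten()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine(np.array([1.0, 0.0]), np.array([1.0, 1.0])))  # ~0.7071
```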
def predict_dino_classify_pipeline(image,draw_result=False):
im_pil = safe_load_img(image)
if draw_result:
draw = ImageDraw.Draw(im_pil)
crops = predict_onnx_yolo(im_pil)
features = {}
for k in crops:
features.update({k:[]})
for v in crops[k]:
features[k].append({"feature":predict_onnx_dino(v['image']),"bbox":v['bbox']})
features["bottom"] = sorted(features["bottom"], key=lambda x: x["bbox"][0])
used_indices = set()
sequence = []
for target in features['bottom']:
available = [(idx, opt) for idx, opt in enumerate(features['top']) if idx not in used_indices]
if not available:
break
if len(available) == 1:
best_idx, best_opt = available[0]
else:
best_idx, best_opt = max(
available,
key=lambda item: predict_dino_classify(target['feature'], item[1]['feature'])
)
sequence.append(best_opt['bbox'])
used_indices.add(best_idx)
colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
points = []
for id,one in enumerate(sequence):
center_x = (one[0] + one[2]) / 2
center_y = (one[1] + one[3]) / 2
w = abs(one[0] - one[2])
y = abs(one[1] - one[3])
points.append((center_x+random.randint(int(-w/5),int(w/5)),
center_y+random.randint(int(-y/5),int(y/5))
))
if draw_result:
draw.rectangle(one, outline=colors[id], width=1)
text = f"{id+1}"
text_position = (center_x, center_y)
draw.text(text_position, text, fill='white')
if draw_result:
save_path_temp = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path_temp)
logger.info(f"图片可视化结果暂时保存在{save_path_temp},运行完成后移至{save_path}")
return points
logger.info(f"使用推理设备: {ort.get_device()}")
def use_pdl():
load_model()
def use_dfine():
load_dfine_model()
def use_multi():
load_yolo11n()
load_dinov3()
load_dino_classify()
model_for = [
{"loader":use_pdl,
"include":["session"],
"support":['paddle','pdl','nine','原神','genshin']
},
{"loader":use_dfine,
"include":["session_dfine"],
"support":['dfine','click','memo','note','‌便笺‌']
},
{"loader":use_multi,
"include":['session_yolo11n', 'session_dino3', 'session_dino_cf'],
"support":['multi','dino','click2','星穹铁道','崩铁','绝区零','zzz','hkrpg']
}
]
def get_models():
res = ["以下是当前加载的模型及其对应关键字"]
for key,value in globals().items():
if key.startswith("session") and value is not None:
for one in model_for:
if key in one['include']:
res.append(f" -{key},关键词:{one['support']}")
return res
def get_available_models():
res = ["以下是所有可用模型及其对应关键字"]
for one in model_for:
res.append(f" -{one['include']}关键词:{one['support']}")
return res
def load_by(name):
for one in model_for:
if name in one['support'] or name in one['include']:
one['loader']()
return get_models()
logger.error(f"不支持的名称,可以使用‌便笺‌、原神、崩铁、绝区零表示")
def unload(*names, safe_mode=True):
import gc
protected_vars = {'__name__', '__file__', '__builtins__',
'unload'}
for name in names:
if name in globals():
if safe_mode and name in protected_vars:
logger.error(f"警告: 跳过保护变量 '{name}'")
continue
if not name.startswith('session'):
logger.info("删除的不是模型!")
var = globals()[name]
if hasattr(var, 'close'):
try:
var.close()
except:
pass
globals()[name] = None
logger.info(f"已释放变量: {name}")
collected = gc.collect()
logger.info(f"垃圾回收器清理了 {collected} 个对象")
return get_models()
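Together, load_by and unload form a small runtime model-swap API; "dfine" and "session_dfine" below come from the model_for table above:

```python
unload("session_dfine")  # closes the ONNX session, nulls the global, runs gc
load_by("dfine")         # reloads it via use_dfine() and returns get_models()
```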
if int(os.environ.get("use_pdl",1)):
use_pdl()
if int(os.environ.get("use_dfine",1)):
use_dfine()
if int(os.environ.get("use_multi",1)):
use_multi()
if __name__ == "__main__":
# use resnet18.onnx
# load_model("resnet18.onnx")
@@ -368,5 +771,7 @@ if __name__ == "__main__":
# use PP-HGNetV2-B4.onnx
#predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)
print(predict_onnx_dfine(r"f:\项目留档\JPEGImages\8bdee494b00d401aae3f496e76d886fc.jpg",True))
# use_multi()
# print(predict_dino_classify_pipeline("0a92e85f89b345279e74deaa9afa9e1c.jpg",True))