From b4d35b99063e62f1a4f5936ae623cdf9cced4514 Mon Sep 17 00:00:00 2001
From: luguoyixiazi <83990760+luguoyixiazi@users.noreply.github.com>
Date: Tue, 4 Nov 2025 16:40:02 +0800
Subject: [PATCH] New models and adaptation for fullpage.9.2.0-guwyxh.js
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #5. Add dinov3, a task head, and yolo11n for passing the captchas that
require them. Also update the new constants in fullpage.9.2.0-guwyxh.js:

```json
{
    "captcha_token": "2064329542",
    "tsfq": "xovrayel"
}
```
---
 crack.py   |  96 ++++++++++-
 main.py    | 148 +++++++++++++----
 predict.py | 469 +++++++++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 643 insertions(+), 70 deletions(-)

diff --git a/crack.py b/crack.py
index 15a6737..89be2a5 100644
--- a/crack.py
+++ b/crack.py
@@ -1,16 +1,19 @@
+import os
 import hashlib
 import json
 import math
 import random
 import time
-
+import logging
 import httpx
+from os import path as PATH
+from crop_image import validate_path
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
 from cryptography.hazmat.primitives import padding, serialization
 from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
-from crop_image import validate_path
-from os import path as PATH
-import os
+
+logger = logging.getLogger(__name__)
+
 class Crack:
     def __init__(self, gt=None, challenge=None):
         self.pic_path = None
@@ -39,6 +42,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
         res = self.session.get(url)
         data = json.loads(res.text[1:-1])["data"]
+        logger.debug(f"gettype fetched again: {data}")
         return data
 
     @staticmethod
@@ -358,6 +362,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         }
         resp = self.session.get("https://api.geetest.com/get.php", params=params).text
         data = json.loads(resp[22:-1])["data"]
+        logger.debug(f"c/s fetch result: {data}")
         self.c = data["c"]
         self.s = data["s"]
         return data["c"], data["s"]
@@ -398,9 +403,79 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         ]
         tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
         rp = self.MD5(self.gt + self.challenge + self.s)
-        temp1 = '''"lang":"zh-cn","type":"fullpage","tt":"%s","light":"DIV_0","s":"c7c3e21112fe4f741921cb3e4ff9f7cb","h":"321f9af1e098233dbd03f250fd2b5e21","hh":"39bd9cad9e425c3a8f51610fd506e3b3","hi":"09eb21b3ae9542a9bc1e8b63b3d9a467","vip_order":-1,"ct":-1,"ep":{"v":"9.1.9-dbjg5z","te":false,"me":true,"ven":"Google Inc. (Intel)","ren":"ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)","fp":["scroll",0,1602,1724571628498,null],"lp":["up",386,217,1724571629854,"pointerup"],"em":{"ph":0,"cp":0,"ek":"11","wd":1,"nt":0,"si":0,"sc":0},"tm":{"a":1724571567311,"b":1724571567549,"c":1724571567562,"d":0,"e":0,"f":1724571567312,"g":1724571567312,"h":1724571567312,"i":1724571567317,"j":1724571567423,"k":1724571567330,"l":1724571567423,"m":1724571567545,"n":1724571567547,"o":1724571567569,"p":1724571568259,"q":1724571568259,"r":1724571568261,"s":1724571570378,"t":1724571570378,"u":1724571570380},"dnf":"dnf","by":0},"passtime":1600,"rp":"%s",''' % (
-            tt, rp)
-        r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
+        payload_dict = {
+            "lang": "zh-cn",
+            "type": "fullpage",
+            "tt": tt,  # use the computed value
+            "light": "DIV_0",
+            "s": "c7c3e21112fe4f741921cb3e4ff9f7cb",
+            "h": "321f9af1e098233dbd03f250fd2b5e21",
+            "hh": "39bd9cad9e425c3a8f51610fd506e3b3",
+            "hi": "09eb21b3ae9542a9bc1e8b63b3d9a467",
+            "vip_order": -1,
+            "ct": -1,
+            "ep": {
+                "v": "9.2.0-guwyxh",
+                "te": False,  # JSON 'false' maps to Python 'False'
+                "me": True,  # JSON 'true' maps to Python 'True'
+                "ven": "Google Inc. (Intel)",
+                "ren": "ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)",
+                "fp": [
+                    "scroll",
+                    0,
+                    1602,
+                    1724571628498,
+                    None  # JSON 'null' maps to Python 'None'
+                ],
+                "lp": [
+                    "up",
+                    386,
+                    217,
+                    1724571629854,
+                    "pointerup"
+                ],
+                "em": {
+                    "ph": 0,
+                    "cp": 0,
+                    "ek": "11",
+                    "wd": 1,
+                    "nt": 0,
+                    "si": 0,
+                    "sc": 0
+                },
+                "tm": {
+                    "a": 1724571567311,
+                    "b": 1724571567549,
+                    "c": 1724571567562,
+                    "d": 0,
+                    "e": 0,
+                    "f": 1724571567312,
+                    "g": 1724571567312,
+                    "h": 1724571567312,
+                    "i": 1724571567317,
+                    "j": 1724571567423,
+                    "k": 1724571567330,
+                    "l": 1724571567423,
+                    "m": 1724571567545,
+                    "n": 1724571567547,
+                    "o": 1724571567569,
+                    "p": 1724571568259,
+                    "q": 1724571568259,
+                    "r": 1724571568261,
+                    "s": 1724571570378,
+                    "t": 1724571570378,
+                    "u": 1724571570380
+                },
+                "dnf": "dnf",
+                "by": 0
+            },
+            "passtime": 1600,
+            "rp": rp,  # use the computed value
+            "captcha_token":"2064329542",
+            "tsfq":"xovrayel"
+        }
+        # r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
+        r = json.dumps(payload_dict, separators=(',', ':'))
         ct = self.aes_encrypt(r)
         s = [byte for byte in ct]
         w = self.encode(s)
@@ -414,6 +489,7 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
             "w": w
         }
         resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
+        logger.debug(f"ajax.php response: {resp}")
         return json.loads(resp[22:-1])["data"]
 
     def get_pic(self):
@@ -435,11 +511,15 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
         }
         resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
         data = json.loads(resp[22:-1])["data"]
+        logger.debug(f"picture fetch result: {data}")
         self.pic_path = data["pic"]
         pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
+
         pic_data = self.session.get(pic_url).content
         pic_name = data["pic"].split("/")[-1]
         pic_type = data["pic_type"]
+        if "80/2023-12-04T12/icon" in pic_url:
+            pic_type = "icon1"
         with open(PATH.join(validate_path,pic_name),'wb+') as f:
             f.write(pic_data)
         return pic_data,pic_name,pic_type
@@ -483,4 +563,4 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
             "w": w
         }
         resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
-        return resp[1:-1]
+        return resp[1:-1]
\ No newline at end of file
diff --git a/main.py b/main.py
index 0dee236..5923e09 100644
--- a/main.py
+++ b/main.py
@@ -4,45 +4,87 @@ import time
 import httpx
 import shutil
 import random
+import uvicorn
 import logging
-import logging.config
-from crack import Crack
-from typing import Optional, Dict, Any
-from contextlib import asynccontextmanager
-from fastapi.responses import JSONResponse
-from fastapi import FastAPI,Query, HTTPException
-from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
-from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
-
-PORT = 9645
-platform = os.name
+import logging.config
+import logging.handlers
+LOG_DIR = "logs"
+os.makedirs(LOG_DIR, exist_ok=True)
+LOG_FILENAME = os.path.join(LOG_DIR, "app.log")
+log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
 # --- 日志配置字典 ---
 LOGGING_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,
     "formatters": {
+        # a formatter named "default"
         "default": {
-            "()": f"uvicorn.{'_logging' if platform== 'nt' else 'logging'}.DefaultFormatter",
-            "fmt": "%(levelprefix)s %(asctime)s | %(message)s",
+            "()": "logging.Formatter",  # use the standard library Formatter
+            "fmt": "[%(levelname)s - %(filename)s:%(lineno)d | %(funcName)s] - %(asctime)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
         },
     },
     "handlers": {
-        "default": {
-            "formatter": "default",
+        # console handler
+        "console": {
             "class": "logging.StreamHandler",
+            "formatter": "default",  # use the "default" formatter defined above
             "stream": "ext://sys.stderr",
         },
+        # rotating file handler
+        "file_rotating": {
+            "class": "logging.handlers.TimedRotatingFileHandler",
+            "formatter": "default",  # also uses the "default" formatter
+            "filename": LOG_FILENAME,  # log file path
+            "when": "D",  # rotate daily ('D' for day)
+            "interval": 1,  # once per day
+            "backupCount": 2,  # keep 2 old log files (3 days in total, counting the current one)
+            "encoding": "utf-8",
+        },
     },
     "loggers": {
-        # 将根日志记录器的级别设置为 INFO
-        "": {"handlers": ["default"], "level": "INFO"},
-        "uvicorn.error": {"level": "INFO"},
-        "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
+        # root logger
+        "": {
+            # send records to both the console and file_rotating handlers
+            "handlers": ["console", "file_rotating"],
+            "level": log_level,
+        },
+        # configure uvicorn's loggers so they also follow these settings
+        "uvicorn": {
+            "handlers": ["console", "file_rotating"],
+            "level": "WARNING",
+            "propagate": False,  # keep uvicorn records from reaching the root logger twice
+        },
+        "uvicorn.error": {
+            "level": "WARNING",
+            "propagate": True,  # uvicorn.error should propagate so the root logger can capture it
+        },
+        "uvicorn.access": {
+            "handlers": ["console", "file_rotating"],
+            "level": log_level,
+            "propagate": False,
+        },
     },
 }
 logging.config.dictConfig(LOGGING_CONFIG)
 logger = logging.getLogger(__name__)
+from crack import Crack
+from typing import Optional, Dict, Any
+from contextlib import asynccontextmanager
+from fastapi.responses import JSONResponse
+from fastapi import FastAPI,Query, HTTPException
+from predict import (predict_onnx,
+                     predict_onnx_pdl,
+                     predict_onnx_dfine,
+                     predict_dino_classify_pipeline,
+                     load_by,
+                     unload,
+                     get_models,
+                     get_available_models)
+from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
+
+PORT = 9645
+
 def get_available_hosts() -> set[str]:
     """获取本机所有可用的IPv4地址。"""
     import socket
@@ -70,7 +111,7 @@ async def lifespan(app: FastAPI):
     port = server.config.port if server else PORT
     if host == "0.0.0.0":
         available_hosts = get_available_hosts()
-        logger.info(f"服务地址(依需求选用,docker中使用宿主机host:{port}):")
+        logger.info(f"Service addresses (choose as needed; inside docker use the host machine's host:{port}; when started via the uvicorn CLI its own options apply):")
         for h in sorted(list(available_hosts)):
             logger.info(f"  - http://{h}:{port}")
     else:
@@ -78,6 +119,7 @@
     logger.info(f"可用服务路径如下:")
     for route in app.routes:
         logger.info(f"  -{route.methods} {route.path}")
+    logger.info(f"See /docs for detailed API usage")
     logger.info("="*50)
 
     yield
@@ -89,9 +131,9 @@ app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)
 
 def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
     """获取信息。"""
-    logging.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
+    logger.info(f"Start fetching:\ngt:{gt}\nchallenge:{challenge}")
     crack = Crack(gt, challenge)
-    crack.gettype()
+    logger.debug(f"gettype first fetch: {crack.gettype()}")
     crack.get_c_s()
     time.sleep(random.uniform(0.4,0.6))
     crack.ajax()
@@ -114,7 +156,12 @@ def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -
 def do_pass_icon(pic:Any, draw_result: bool) -> list[str]:
     """处理图标点选验证码,返回坐标点列表。"""
     result_list = predict_onnx_dfine(pic,draw_result)
-    print(result_list)
+    logger.debug(result_list)
+    return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
+
+def do_pass_icon0(pic:Any, draw_result: bool) -> list[str]:
+    """Handle the icon-click captcha with the dinov3 pipeline; returns a list of coordinate points."""
+    result_list = predict_dino_classify_pipeline(pic,draw_result)
     return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
 
 def save_image_for_train(pic_name,pic_type,passed):
@@ -141,14 +188,16 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
         crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
 
         # 2. 识别
-
+        logger.debug(f"Received picture type: {pic_type}")
         if pic_type == "nine":
             point_list = do_pass_nine(
                 pic_content,
                 use_v3_model=kwargs.get("use_v3_model", True),
                 point=kwargs.get("point",None)
             )
-        elif pic_type == "icon":
+        elif pic_type == 'icon':  # dinov3 pipeline
+            point_list = do_pass_icon0(pic_content, save_result)
+        elif pic_type == "icon1":  # d-fine
             point_list = do_pass_icon(pic_content, save_result)
         else:
             raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
@@ -169,14 +218,14 @@ def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) ->
             os.remove(os.path.join(validate_path,pic_name))
 
         total_time = time.monotonic() - start_time
-        logging.info(
+        logger.info(
             f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
             f"结果: {result}"
         )
         return JSONResponse(content=result)
 
     except Exception as e:
-        logging.error(f"服务错误: {e}", exc_info=True)
+        logger.error(f"Service error: {e}", exc_info=True)
         return JSONResponse(
             status_code=500,
             content={"error": "An internal server error occurred.", "detail": str(e)}
@@ -227,12 +276,50 @@ def pass_hutao(gt: str = Query(...),
             return JSONResponse(content=rebuild_content, status_code=original_status_code)
 
     except Exception as e:
-        logging.error(f"修改路由错误: {e}", exc_info=True)
+        logger.error(f"Rewritten-route error: {e}", exc_info=True)
         return JSONResponse(
             status_code=500,
             content={"error": "An internal server error occurred.", "detail": str(e)}
         )
 
+@app.get("/list_model")
+def list_model():
+    return JSONResponse(content = get_models())
+@app.get("/list_all_model")
+def list_all_model():
+    return JSONResponse(content = get_available_models())
+@app.get("/load_model")
+def load_model(name: str = Query(...)):
+    return JSONResponse(content = load_by(name))
+@app.get("/unload_model")
+def unload_model(name: str = Query(...)):
+    return JSONResponse(content = unload(name))
+
+@app.get("/set_log_level")
+def set_log_level(level: str = Query(...)):
+    """
+    Dynamically change the log level of all major loggers while the service is running.
+    Example: /set_log_level?level=DEBUG
+    """
+    # convert the level string to the logging module's numeric value
+    level_str = str(level).upper()
+    numeric_level = getattr(logging, level_str, None)
+    if not isinstance(numeric_level, int):
+        raise HTTPException(status_code=400, detail=f"Invalid log level: {level}")
+
+    # adjust every key logger from the configuration
+    # 1. update the root logger
+    logging.getLogger().setLevel(numeric_level)
+    # 2. update the uvicorn loggers
+    logging.getLogger("uvicorn").setLevel(numeric_level)
+    logging.getLogger("uvicorn.error").setLevel(numeric_level)
+    logging.getLogger("uvicorn.access").setLevel(numeric_level)
+
+    # emit one high-level record so the change is sure to be seen
+    logger.warning(f"Log level of all loggers dynamically changed to: {level_str}")
+
+    return JSONResponse(content = f"Log level successfully set to {level_str}")
+
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app,port=PORT)
\ No newline at end of file
+    uvicorn.run(app,port=PORT,log_config=LOGGING_CONFIG)
+    
\ No newline at end of file
diff --git a/predict.py b/predict.py
index 31528e4..137beea 100644
--- a/predict.py
+++ b/predict.py
@@ -1,12 +1,40 @@
 import os
-import numpy as np
-from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
 import time
-from PIL import Image, ImageDraw
+import base64
+import random
+import logging
+import numpy as np
 from io import BytesIO
 import onnxruntime as ort
-
-
+from PIL import Image, ImageDraw
+from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path,save_path
+logger = logging.getLogger(__name__)
+def safe_load_img(image):
+    im_pil = None
+    try:
+        if isinstance(image, Image.Image):
+            im_pil = image
+        elif isinstance(image, str):
+            try:
+                im_pil = Image.open(image)
+            except (IOError, FileNotFoundError):
+                if ',' in image:
+                    image = image.split(',')[-1]
+                padding = len(image) % 4
+                if padding > 0:
+                    image += '=' * (4 - padding)
+                img_bytes = base64.b64decode(image)
+                im_pil = Image.open(BytesIO(img_bytes))
+        elif isinstance(image, bytes):
+            im_pil = bytes_to_pil(image)
+        elif isinstance(image, np.ndarray):
+            im_pil = Image.fromarray(image)
+        else:
+            raise ValueError(f"Unsupported input type: {type(image)}")
+        return im_pil.convert("RGB")
+    except Exception as e:
+        raise ValueError(f"Failed to load or parse the image: {e}")
+
 def predict(icon_image, bg_image):
     import torch
     from train import MyResNet18, data_transform
@@ -34,7 +62,7 @@ def predict(icon_image, bg_image):
     model = MyResNet18(num_classes=91)  # 这里的类别数要与训练时一致
     model.load_state_dict(torch.load(model_path))
     model.eval()
-    print("加载模型,耗时:", time.time() - start)
+    logger.info(f"Model loaded, took: {time.time() - start}")
 
     start = time.time()
     target_images = torch.stack(target_images, dim=0)
@@ -52,14 +80,14 @@ def predict(icon_image, bg_image):
         )
         scores.append(similarity.cpu().item())
     # 从左到右,从上到下,依次为每张图片的置信度
-    print(scores)
+    logger.info(scores)
     # 对数组进行排序,保持下标
     indexed_arr = list(enumerate(scores))
     sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
     # 提取最大三个数及其下标
     largest_three = sorted_arr[:3]
-    print(largest_three)
-    print("识别完成,耗时:", time.time() - start)
+    logger.info(largest_three)
+    logger.info(f"Recognition finished, took: {time.time() - start}")
 
 def load_model(name='PP-HGNetV2-B4.onnx'):
     # 加载onnx模型
@@ -69,7 +97,7 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
     model_path = os.path.join(current_dir, 'model', name)
     session = ort.InferenceSession(model_path)
     input_name = session.get_inputs()[0].name
-    print(f"加载{name}模型,耗时:{time.time() - start}")
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")
 
 def load_dfine_model(name='d-fine-n.onnx'):
     # 加载onnx模型
@@ -78,8 +106,31 @@ def load_dfine_model(name='d-fine-n.onnx'):
     current_dir = os.path.dirname(os.path.abspath(__file__))
     model_path = os.path.join(current_dir, 'model', name)
     session_dfine = ort.InferenceSession(model_path)
-    print(f"加载{name}模型,耗时:{time.time() - start}")
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")
 
+def load_yolo11n(name='yolo11n.onnx'):
+    global session_yolo11n
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_yolo11n = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")
+
+def load_dinov3(name='dinov3-small.onnx'):
+    global session_dino3
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_dino3 = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")
+
+def load_dino_classify(name='atten.onnx'):
+    global session_dino_cf
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session_dino_cf = ort.InferenceSession(model_path)
+    logger.info(f"Loaded model {name}, took: {time.time() - start}")
 
 def predict_onnx(icon_image, bg_image, point = None):
     import cv2
@@ -137,8 +188,7 @@ def predict_onnx(icon_image, bg_image, point = None):
         else:
             similarity = cosine_similarity(target_output, out_put)
         scores.append(similarity)
-    # 从左到右,从上到下,依次为每张图片的置信度
-    # print(scores)
+    logger.debug(f"Tile confidences, left to right then top to bottom:\n{scores}")
     # 对数组进行排序,保持下标
     indexed_arr = list(enumerate(scores))
     sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
@@ -149,7 +199,7 @@ def predict_onnx(icon_image, bg_image, point = None):
     # 基于分数判断
     else:
         answer = [one[0] for one in sorted_arr if one[1] > point]
-    print(f"识别完成{answer},耗时: {time.time() - start}")
+    logger.info(f"Recognition finished {answer}, took: {time.time() - start}")
     #draw_points_on_image(bg_image, answer)
     return answer
 
@@ -202,40 +252,79 @@ def predict_onnx_pdl(images_path):
     if len(answer) == 0:
         all_sort =[np.argsort(one) for one in outputs]
         answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
-    print(f"识别完成{answer},耗时: {time.time() - start}")
+    logger.info(f"Recognition finished {answer}, took: {time.time() - start}")
     with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
         bg_image = f.read()
-    # draw_points_on_image(bg_image, answer)
+    #draw_points_on_image(bg_image, answer)
     return answer
-
+
+# d-fine inference code and helpers
 def calculate_iou(boxA, boxB):
+    """
+    Compute the intersection-over-union (IoU) of two bounding boxes with NumPy.
+    """
+    # coordinates of the intersection rectangle
     xA = np.maximum(boxA[0], boxB[0])
     yA = np.maximum(boxA[1], boxB[1])
     xB = np.minimum(boxA[2], boxB[2])
     yB = np.minimum(boxA[3], boxB[3])
+    # area of the intersection
     intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
+
+    # areas of the two boxes
     boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
     boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
+
+    # area of the union
     union_area = float(boxA_area + boxB_area - intersection_area)
+
+    # compute the IoU
     if union_area == 0:
-        return 0.0
+        return 0.0 
+
     iou = intersection_area / union_area
     return iou
 
+
 def non_maximum_suppression(detections, iou_threshold=0.35):
+    """
+    Run non-maximum suppression (NMS) on the detection results.
+
+    Args:
+    detections -- a list whose elements are dicts containing 'box' and 'score',
+                  e.g. [{'box': [x1, y1, x2, y2], 'score': 0.9, ...}, ...]
+    iou_threshold -- IoU threshold used to decide that two boxes overlap.
+
+    Returns:
+    final_detections -- the detections kept after NMS.
+    """
+    # 1. nothing to do if there are no detections
     if not detections:
         return []
+
+    # 2. sort the boxes by confidence (score), highest first
+    # the lambda selects the sort key
     detections.sort(key=lambda x: x['score'], reverse=True)
     final_detections = []
+
+    # 3. keep looping until no detections remain
     while detections:
+        # 4. move the current highest-scoring detection (the first one)
+        # into the final list and drop it from the working list
         best_detection = detections.pop(0)
         final_detections.append(best_detection)
+
+        # 5. compute the IoU between the box just kept and every remaining box,
+        #    keeping only those whose IoU is below the threshold
         detections_to_keep = []
         for det in detections:
+            # note: ideally only boxes of the same class would be suppressed together
             iou = calculate_iou(best_detection['box'], det['box'])
             if iou < iou_threshold:
                 detections_to_keep.append(det)
+
+        # replace the working list with the filtered one for the next iteration
         detections = detections_to_keep
 
     return final_detections
 
@@ -246,10 +335,7 @@ def predict_onnx_dfine(image,draw_result=False):
     image_input_name = input_nodes[0].name
     size_input_name = input_nodes[1].name
     output_names = [node.name for node in output_nodes]
-    if isinstance(image,bytes):
-        im_pil = bytes_to_pil(image)
-    else:
-        im_pil = Image.open(image).convert("RGB")
+    im_pil = safe_load_img(image)
     w, h = im_pil.size
     orig_size_np = np.array([[w, h]], dtype=np.int64)
     im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
@@ -340,18 +426,335 @@ def predict_onnx_dfine(image,draw_result=False):
             text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
             draw.text(text_position, text, fill=color)
     if draw_result:
-        save_path = os.path.join(validate_path,"icon_result.jpg")
-        im_pil.save(save_path)
-        print(f"图片可视化结果保存在{save_path}")
-    print(f"图片顺序的中心点{points}")
+        save_path_temp = os.path.join(validate_path,"icon_result.jpg")
+        im_pil.save(save_path_temp)
+        logger.info(f"Visualization saved temporarily to {save_path_temp}; it is moved to {save_path} once the run completes")
+    logger.info(f"Center points in picture order: {points}")
     return points
-
-print(f"使用推理设备: {ort.get_device()}")
+# yolo inference code and helpers
+def predict_onnx_yolo(image):
+    def filter_Detections(results, thresh = 0.5):
+        results = results[0]
+        results = results.transpose()
+        # if model is trained on 1 class only
+        if len(results[0]) == 5:
+            # filter out the detections with confidence > thresh
+            considerable_detections = [detection for detection in results if detection[4] > thresh]
+            considerable_detections = np.array(considerable_detections)
+            return considerable_detections
+
+        # if model is trained on multiple classes
+        else:
+            A = []
+            for detection in results:
+
+                class_id = detection[4:].argmax()
+                confidence_score = detection[4:].max()
+
+                new_detection = np.append(detection[:4],[class_id,confidence_score])
+
+                A.append(new_detection)
+
+            A = np.array(A)
+
+            # filter out the detections with confidence > thresh
+            considerable_detections = [detection for detection in A if detection[-1] > thresh]
+            considerable_detections = np.array(considerable_detections)
+
+            return considerable_detections
+    def NMS(boxes, conf_scores, iou_thresh = 0.55):
+
+        # boxes [[x1,y1, x2,y2], [x1,y1, x2,y2], ...]
+
+        x1 = boxes[:,0]
+        y1 = boxes[:,1]
+        x2 = boxes[:,2]
+        y2 = boxes[:,3]
+
+        areas = (x2-x1)*(y2-y1)
+
+        order = conf_scores.argsort()
+
+        keep = []
+        keep_confidences = []
+
+        while len(order) > 0:
+            idx = order[-1]
+            A = boxes[idx]
+            conf = conf_scores[idx]
+
+            order = order[:-1]
+
+            xx1 = np.take(x1, indices= order)
+            yy1 = np.take(y1, indices= order)
+            xx2 = np.take(x2, indices= order)
+            yy2 = np.take(y2, indices= order)
+
+            keep.append(A)
+            keep_confidences.append(conf)
+
+            # iou = inter/union
+
+            xx1 = np.maximum(x1[idx], xx1)
+            yy1 = np.maximum(y1[idx], yy1)
+            xx2 = np.minimum(x2[idx], xx2)
+            yy2 = np.minimum(y2[idx], yy2)
+
+            w = np.maximum(xx2-xx1, 0)
+            h = np.maximum(yy2-yy1, 0)
+
+            intersection = w*h
+
+            # union = areaA + other_areas - intersection
+            other_areas = np.take(areas, indices= order)
+            union = areas[idx] + other_areas - intersection
+
+            iou = intersection/union
+
+            booleans = iou < iou_thresh
+
+            order = order[booleans]
+
+            # e.g. order = [2,0,1], booleans = [True, False, True]
+            # -> order = [2,1]
+
+        return keep, keep_confidences
+    def rescale_back(results,img_w,img_h,imgsz=384):
+        cx, cy, w, h, class_id, confidence = results[:,0], results[:,1], results[:,2], results[:,3], results[:, 4], results[:,-1]
+        cx = cx/imgsz * img_w
+        cy = cy/imgsz * img_h
+        w = w/imgsz * img_w
+        h = h/imgsz * img_h
+        x1 = cx - w/2
+        y1 = cy - h/2
+        x2 = cx + w/2
+        y2 = cy + h/2
+
+        boxes = np.column_stack((x1, y1, x2, y2, class_id))
+        keep, keep_confidences = NMS(boxes,confidence)
+        return keep, keep_confidences
+    im_pil = safe_load_img(image)
+    im_resized = im_pil.resize((384, 384), Image.Resampling.BILINEAR)
+    im_data = np.array(im_resized, dtype=np.float32) / 255.0
+    im_data = im_data.transpose(2, 0, 1)
+    im_data = np.expand_dims(im_data, axis=0)
+    res = session_yolo11n.run(None,{"images":im_data})
+    results = filter_Detections(res)
+    rescaled_results, confidences = rescale_back(results,im_pil.size[0],im_pil.size[1])
+    images = {"top":[],"bottom":[]}
+    for r, conf in zip(rescaled_results, confidences):
+        x1,y1,x2,y2, cls_id = r
+        cls_id = int(cls_id)
+        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+        cropped_image = im_pil.crop((x1, y1, x2, y2))
+        if cls_id == 0:
+            images['top'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
+        else:
+            images['bottom'].append({"image":cropped_image,"bbox":[x1, y1, x2, y2]})
+    return images
+
+# dinov3 inference code and helpers
+def make_lvd_transform(resize_size: int = 224):
+    """
+    Return an image preprocessing function equivalent to the PyTorch version.
+    """
+    # normalization parameters
+    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+
+    def transform(image) -> np.ndarray:
+        """
+        Image preprocessing transform.
+
+        Args:
+            image: PIL Image or numpy array (H,W,C) in the range [0,255]
+
+        Returns:
+            numpy array (C,H,W), normalized float32
+        """
+        # make sure the input is a PIL image
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image.astype('uint8'))
+
+        # 1. resize (LANCZOS resampling, matching antialias=True)
+        image = image.resize((resize_size, resize_size), Image.LANCZOS)
+
+        # 2. convert to a numpy array, adjusting dtype and range
+        # PIL image becomes a numpy array (H,W,C) in [0,255]
+        image_array = np.array(image, dtype=np.float32)
+
+        # if the image is RGBA, keep only the RGB channels
+        if image_array.shape[-1] == 4:
+            image_array = image_array[:, :, :3]
+
+        # scale to the [0,1] range (matches scale=True)
+        image_array /= 255.0
+
+        # 3. normalize
+        # note: PyTorch's Normalize works per channel
+        image_array = (image_array - mean) / std
+
+        # 4. transpose from (H,W,C) to (C,H,W), matching the PyTorch tensor layout
+        image_array = np.transpose(image_array, (2, 0, 1))
+
+        return image_array
+
+    return transform
+transform = make_lvd_transform(224)
+
+def predict_onnx_dino(image):
+    im_pil = safe_load_img(image)
+    input_name_model = session_dino3.get_inputs()[0].name
+    output_name_model = session_dino3.get_outputs()[0].name
+    return session_dino3.run([output_name_model],
+                             {input_name_model:
+                                 np.expand_dims(transform(im_pil), axis=0).astype(np.float32)
+                             }
+                             )[0]
+
+# inference code and helpers for classifying dinov3 features
+def predict_dino_classify(tokens1,tokens2):
+    patch_tokens1 = tokens1[:, 5:, :]
+    patch_tokens2 = tokens2[:, 5:, :]
+    input_name_model = session_dino_cf.get_inputs()[0].name
+    output_name_model = session_dino_cf.get_outputs()[0].name
+    emb1 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens1})[0]
+    emb2 = session_dino_cf.run([output_name_model], {input_name_model: patch_tokens2})[0]
+    emb1_flat = emb1.flatten()
+    emb2_flat = emb2.flatten()
+    return float(np.dot(emb1_flat, emb2_flat) / (np.linalg.norm(emb1_flat) * np.linalg.norm(emb2_flat)))
+
+def predict_dino_classify_pipeline(image,draw_result=False):
+    im_pil = safe_load_img(image)
+    if draw_result:
+        draw = ImageDraw.Draw(im_pil)
+    crops = predict_onnx_yolo(im_pil)
+    features = {}
+    for k in crops:
+        features.update({k:[]})
+        for v in crops[k]:
+            features[k].append({"feature":predict_onnx_dino(v['image']),"bbox":v['bbox']})
+    features["bottom"] = sorted(features["bottom"], key=lambda x: x["bbox"][0])
+    used_indices = set()
+    sequence = []
+
+    for target in features['bottom']:
+        available = [(idx, opt) for idx, opt in enumerate(features['top']) if idx not in used_indices]
+
+        if not available:
+            break
+
+        if len(available) == 1:
+            best_idx, best_opt = available[0]
+        else:
+            best_idx, best_opt = max(
+                available,
+                key=lambda item: predict_dino_classify(target['feature'], item[1]['feature'])
+            )
+
+        sequence.append(best_opt['bbox'])
+        used_indices.add(best_idx)
+    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
+    points = []
+    for id,one in enumerate(sequence):
+        center_x = (one[0] + one[2]) / 2
+        center_y = (one[1] + one[3]) / 2
+        w = abs(one[0] - one[2])
+        y = abs(one[1] - one[3])
+        points.append((center_x+random.randint(int(-w/5),int(w/5)),
+                       center_y+random.randint(int(-y/5),int(y/5))
+                       ))
+        if draw_result:
+            draw.rectangle(one, outline=colors[id], width=1)
+            text = f"{id+1}"
+            text_position = (center_x, center_y)
+            draw.text(text_position, text, fill='white')
+    if draw_result:
+        save_path_temp = os.path.join(validate_path,"icon_result.jpg")
+        im_pil.save(save_path_temp)
+        logger.info(f"Visualization saved temporarily to {save_path_temp}; it is moved to {save_path} once the run completes")
+    return points
+
+
+logger.info(f"Inference device: {ort.get_device()}")
+def use_pdl():
+    load_model()
+
+def use_dfine():
+    load_dfine_model()
+
+def use_multi():
+    load_yolo11n()
+    load_dinov3()
+    load_dino_classify()
+
+model_for = [
+    {"loader":use_pdl,
+     "include":["session"],
+     "support":['paddle','pdl','nine','原神','genshin']
+    },
+    {"loader":use_dfine,
+     "include":["session_dfine"],
+     "support":['dfine','click','memo','note','便笺']
+    },
+    {"loader":use_multi,
+     "include":['session_yolo11n', 'session_dino3', 'session_dino_cf'],
+     "support":['multi','dino','click2','星穹铁道','崩铁','绝区零','zzz','hkrpg']
+    }
+]
+
+def get_models():
+    res = ["Currently loaded models and their keywords:"]
+    for key,value in globals().items():
+        if key.startswith("session") and value is not None:
+            for one in model_for:
+                if key in one['include']:
+                    res.append(f"  -{key}, keywords: {one['support']}")
+    return res
+def get_available_models():
+    res = ["All available models and their keywords:"]
+    for one in model_for:
+        res.append(f"  -{one['include']}, keywords: {one['support']}")
+    return res
+
+def load_by(name):
+    for one in model_for:
+        if name in one['support'] or name in one['include']:
+            one['loader']()
+            return get_models()
+    logger.error(f"Unsupported name; names such as 便笺, 原神, 崩铁 and 绝区零 are accepted")
+
+def unload(*names, safe_mode=True):
+    import gc
+    protected_vars = {'__name__', '__file__', '__builtins__',
+                      'unload'}
+    for name in names:
+        if name in globals():
+            if safe_mode and name in protected_vars:
+                logger.error(f"Warning: skipping protected variable '{name}'")
+                continue
+            if not name.startswith('session'):
+                logger.info("The variable being removed is not a model!")
+            var = globals()[name]
+            if hasattr(var, 'close'):
+                try:
+                    var.close()
+                except:
+                    pass
+            globals()[name] = None
+            logger.info(f"Released variable: {name}")
+    collected = gc.collect()
+    logger.info(f"Garbage collector reclaimed {collected} objects")
+    return get_models()
+
 if int(os.environ.get("use_pdl",1)):
-    load_model()
+    use_pdl()
 if int(os.environ.get("use_dfine",1)):
-    load_dfine_model()
+    use_dfine()
+if int(os.environ.get("use_multi",1)):
+    use_multi()
+
 if __name__ == "__main__":
     # 使用resnet18.onnx
     # load_model("resnet18.onnx")
@@ -368,5 +771,7 @@ if __name__ == "__main__":
 
     # 使用PP-HGNetV2-B4.onnx
     #predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
-    predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)
-
+    print(predict_onnx_dfine(r"f:\项目留档\JPEGImages\8bdee494b00d401aae3f496e76d886fc.jpg",True))
+    # use_multi()
+    # print(predict_dino_classify_pipeline("0a92e85f89b345279e74deaa9afa9e1c.jpg",True))
+    
\ No newline at end of file
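
For reviewers who want to exercise the new model-management and logging endpoints, here is a minimal smoke-test sketch. It is not part of the patch: it assumes the service is running locally on the default PORT 9645 from main.py, and it uses httpx, which the project already imports; the keyword "multi" and the variable name "session_yolo11n" come from the model_for registry in predict.py.

```python
# Minimal client-side smoke test for the endpoints added in main.py
# (hypothetical usage sketch; assumes the service runs on localhost:9645).
import httpx

BASE = "http://127.0.0.1:9645"

# list the currently loaded models and every available model
print(httpx.get(f"{BASE}/list_model").json())
print(httpx.get(f"{BASE}/list_all_model").json())

# load the yolo11n + dinov3 pipeline by one of its keywords ('multi'),
# then release one of its sessions by global variable name
print(httpx.get(f"{BASE}/load_model", params={"name": "multi"}).json())
print(httpx.get(f"{BASE}/unload_model", params={"name": "session_yolo11n"}).json())

# raise log verbosity at runtime
print(httpx.get(f"{BASE}/set_log_level", params={"level": "DEBUG"}).json())
```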