mirror of
https://github.com/luguoyixiazi/test_nine.git
synced 2025-12-06 14:52:49 +08:00
增加d-fine模型检测V3的图标点选(snap hutao可用)
看见这个老哥做了一份(https://github.com/taskmgr818/geetest-v3-click-server),但是用ddddocr的话就太重了,刚好一直想炼d-fine,就在哈基米2.5pro的帮助下做了数据集生成就开炉了,原文数据加载时做了一些几何变换,但是不适合验证码的框选,所以我把数据集的变换全写在生成代码里面了,效果挺不错的,没细测,挑了几张都完美pass
This commit is contained in:
7
crack.py
7
crack.py
@@ -10,8 +10,8 @@ from cryptography.hazmat.primitives import padding, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||
from crop_image import validate_path
|
||||
from os import path as PATH
|
||||
|
||||
|
||||
import os
|
||||
os.environ['https_proxy']="http://127.0.0.1:10809"
|
||||
class Crack:
|
||||
def __init__(self, gt=None, challenge=None):
|
||||
self.pic_path = None
|
||||
@@ -440,9 +440,10 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
|
||||
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
|
||||
pic_data = self.session.get(pic_url).content
|
||||
pic_name = data["pic"].split("/")[-1]
|
||||
pic_type = data["pic_type"]
|
||||
with open(PATH.join(validate_path,pic_name),'wb+') as f:
|
||||
f.write(pic_data)
|
||||
return pic_data,pic_name
|
||||
return pic_data,pic_name,pic_type
|
||||
|
||||
def verify(self, points: list):
|
||||
u = self.enc_key
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from PIL import Image, ImageFont, ImageDraw, ImageOps
|
||||
from io import BytesIO
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
current_path = os.getcwd()
|
||||
@@ -12,8 +11,8 @@ os.makedirs(validate_path,exist_ok=True)
|
||||
os.makedirs(save_path,exist_ok=True)
|
||||
os.makedirs(save_pass_path,exist_ok=True)
|
||||
os.makedirs(save_fail_path,exist_ok=True)
|
||||
|
||||
def draw_points_on_image(bg_image, answer):
|
||||
import cv2
|
||||
# 将背景图片转换为OpenCV格式
|
||||
bg_image_cv = cv2.imdecode(np.frombuffer(bg_image, dtype=np.uint8), cv2.IMREAD_COLOR)
|
||||
|
||||
@@ -53,6 +52,10 @@ def convert_png_to_jpg(png_bytes: bytes) -> bytes:
|
||||
# 返回保存后的 JPG 图像的 bytes
|
||||
return output_bytes.getvalue()
|
||||
|
||||
def bytes_to_pil(image_bytes):
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
image = image.convert('RGB')
|
||||
return image
|
||||
|
||||
def crop_image(image_bytes, coordinates):
|
||||
img = Image.open(BytesIO(image_bytes))
|
||||
@@ -84,8 +87,7 @@ def crop_image_v3(image_bytes):
|
||||
[[232, 232], [344, 344]],#第三行
|
||||
[[2, 344], [42, 384]] #要验证的
|
||||
]
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
|
||||
image = bytes_to_pil(image_bytes)
|
||||
imageNew = Image.new('RGB', (300,261),(0,0,0))
|
||||
images = []
|
||||
for i, (start_point, end_point) in enumerate(coordinates):
|
||||
|
||||
252
main.py
252
main.py
@@ -1,43 +1,105 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
from crack import Crack
|
||||
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
|
||||
import httpx
|
||||
from fastapi import FastAPI,Query
|
||||
from fastapi.responses import JSONResponse
|
||||
import shutil
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
import logging.config
|
||||
from crack import Crack
|
||||
from typing import Optional, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import FastAPI,Query, HTTPException
|
||||
from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
|
||||
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
|
||||
|
||||
port = 9645
|
||||
# api
|
||||
app = FastAPI()
|
||||
PORT = 9645
|
||||
# --- 日志配置字典 ---
|
||||
LOGGING_CONFIG = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {
|
||||
"()": "uvicorn._logging.DefaultFormatter",
|
||||
"fmt": "%(levelprefix)s %(asctime)s | %(message)s",
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
},
|
||||
},
|
||||
"handlers": {
|
||||
"default": {
|
||||
"formatter": "default",
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": "ext://sys.stderr",
|
||||
},
|
||||
},
|
||||
"loggers": {
|
||||
# 将根日志记录器的级别设置为 INFO
|
||||
"": {"handlers": ["default"], "level": "INFO"},
|
||||
"uvicorn.error": {"level": "INFO"},
|
||||
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
|
||||
},
|
||||
}
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
logger = logging.getLogger(__name__)
|
||||
def get_available_hosts() -> set[str]:
|
||||
"""获取本机所有可用的IPv4地址。"""
|
||||
import socket
|
||||
hosts = {"127.0.0.1"}
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
addr_info = socket.getaddrinfo(hostname, None, socket.AF_INET)
|
||||
hosts.update({info[4][0] for info in addr_info})
|
||||
except socket.gaierror:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("8.8.8.8", 80))
|
||||
hosts.add(s.getsockname()[0])
|
||||
except OSError:
|
||||
pass
|
||||
return hosts
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("="*50)
|
||||
logger.info("启动服务中...")
|
||||
# 从 uvicorn 配置中获取 host 和 port
|
||||
server = app.servers[0] if app.servers else None
|
||||
host = server.config.host if server else "0.0.0.0"
|
||||
port = server.config.port if server else PORT
|
||||
if host == "0.0.0.0":
|
||||
available_hosts = get_available_hosts()
|
||||
logger.info(f"服务地址(依需求选用,docker中使用宿主机host:{port}):")
|
||||
for h in sorted(list(available_hosts)):
|
||||
logger.info(f" - http://{h}:{port}")
|
||||
else:
|
||||
logger.info(f"服务地址: http://{host}:{port}")
|
||||
logger.info(f"可用服务路径如下:")
|
||||
for route in app.routes:
|
||||
logger.info(f" -{route.methods} {route.path}")
|
||||
logger.info("="*50)
|
||||
|
||||
@app.get("/pass_nine")
|
||||
def get_pic(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
point: str = Query(default=None),
|
||||
use_v3_model = Query(default=True),
|
||||
save_result = Query(default=False)
|
||||
):
|
||||
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
|
||||
t = time.time()
|
||||
yield
|
||||
logger.info("="*50)
|
||||
logger.info("服务关闭")
|
||||
logger.info("="*50)
|
||||
|
||||
app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)
|
||||
|
||||
def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
|
||||
"""获取信息。"""
|
||||
logging.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
|
||||
crack = Crack(gt, challenge)
|
||||
|
||||
crack.gettype()
|
||||
|
||||
crack.get_c_s()
|
||||
|
||||
time.sleep(random.uniform(0.4,0.6))
|
||||
|
||||
crack.ajax()
|
||||
pic_content,pic_name,pic_type = crack.get_pic()
|
||||
return crack,pic_content,pic_name,pic_type
|
||||
|
||||
pic_content,pic_name = crack.get_pic()
|
||||
|
||||
def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -> list[str]:
|
||||
"""处理九宫格验证码,返回坐标点列表。"""
|
||||
crop_image_v3(pic_content)
|
||||
|
||||
if use_v3_model:
|
||||
result_list = predict_onnx_pdl(validate_path)
|
||||
else:
|
||||
@@ -46,30 +108,130 @@ def get_pic(gt: str = Query(...),
|
||||
with open(f"{validate_path}/nine.jpg", "rb") as rb:
|
||||
bg_image = rb.read()
|
||||
result_list = predict_onnx(icon_image, bg_image, point)
|
||||
return [f"{col}_{row}" for row, col in result_list]
|
||||
|
||||
point_list = [f"{col}_{row}" for row, col in result_list]
|
||||
wait_time = max(0,4.0 - (time.time() - t))
|
||||
time.sleep(wait_time)
|
||||
result = json.loads(crack.verify(point_list))
|
||||
if save_result:
|
||||
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
|
||||
if 'validate' in result['data']:
|
||||
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
|
||||
def do_pass_icon(pic:Any, draw_result: bool) -> list[str]:
|
||||
"""处理图标点选验证码,返回坐标点列表。"""
|
||||
result_list = predict_onnx_dfine(pic,draw_result)
|
||||
print(result_list)
|
||||
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
|
||||
|
||||
def save_image_for_train(pic_name,pic_type,passed):
|
||||
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
|
||||
if passed:
|
||||
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
|
||||
else:
|
||||
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
|
||||
os.makedirs(path_2_save,exist_ok=True)
|
||||
for pic in os.listdir(validate_path):
|
||||
if pic_type == "nine" and pic.startswith('cropped'):
|
||||
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
|
||||
if pic_type == "icon" and pic.startswith('icon'):
|
||||
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
|
||||
|
||||
|
||||
def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) -> JSONResponse:
|
||||
"""
|
||||
统一处理所有验证码请求的核心函数。
|
||||
"""
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
# 1. 准备
|
||||
crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
|
||||
|
||||
# 2. 识别
|
||||
|
||||
if pic_type == "nine":
|
||||
point_list = do_pass_nine(
|
||||
pic_content,
|
||||
use_v3_model=kwargs.get("use_v3_model", True),
|
||||
point=kwargs.get("point",None)
|
||||
)
|
||||
elif pic_type == "icon":
|
||||
point_list = do_pass_icon(pic_content, save_result)
|
||||
else:
|
||||
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
|
||||
os.makedirs(path_2_save,exist_ok=True)
|
||||
for pic in os.listdir(validate_path):
|
||||
if pic.startswith('cropped'):
|
||||
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
|
||||
total_time = time.time() - t
|
||||
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
|
||||
return JSONResponse(content=result)
|
||||
raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
|
||||
|
||||
# 3. 验证
|
||||
elapsed = time.monotonic() - start_time
|
||||
wait_time = max(0, 4.0 - elapsed)
|
||||
time.sleep(wait_time)
|
||||
|
||||
response_str = crack.verify(point_list)
|
||||
result = json.loads(response_str)
|
||||
|
||||
# 4. 后处理
|
||||
passed = 'validate' in result.get('data', {})
|
||||
if save_result:
|
||||
save_image_for_train(pic_name, pic_type, passed)
|
||||
else:
|
||||
os.remove(os.path.join(validate_path,pic_name))
|
||||
|
||||
total_time = time.monotonic() - start_time
|
||||
logging.info(
|
||||
f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
|
||||
f"结果: {result}"
|
||||
)
|
||||
return JSONResponse(content=result)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"服务错误: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "An internal server error occurred.", "detail": str(e)}
|
||||
)
|
||||
|
||||
|
||||
|
||||
@app.get("/pass_nine")
|
||||
def pass_nine(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
point: str = Query(default=None),
|
||||
use_v3_model = Query(default=True),
|
||||
save_result = Query(default=False)
|
||||
):
|
||||
return handle_pass_request(
|
||||
gt, challenge, save_result,
|
||||
use_v3_model=use_v3_model, point=point
|
||||
)
|
||||
|
||||
@app.get("/pass_icon")
|
||||
def pass_icon(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
save_result = Query(default=False)
|
||||
):
|
||||
return handle_pass_request(gt, challenge, save_result)
|
||||
|
||||
@app.get("/pass_uni")
|
||||
def pass_uni(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
save_result = Query(default=False)
|
||||
):
|
||||
return handle_pass_request(gt, challenge, save_result)
|
||||
|
||||
@app.get("/pass_hutao")
|
||||
def pass_hutao(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
save_result = Query(default=False)):
|
||||
try:
|
||||
# 调用原函数获取返回值
|
||||
response = handle_pass_request(gt, challenge, save_result)
|
||||
# 获取原始状态码和内容
|
||||
original_status_code = response.status_code
|
||||
original_content = json.loads(response.body.decode("utf-8"))
|
||||
if original_status_code == 200 and original_content.get("status",False)=="success" and "validate" in original_content.get("data",{}):
|
||||
rebuild_content = {"code":0,"data":{"gt":gt,"challenge":challenge,"validate":original_content["data"]["validate"]}}
|
||||
else:
|
||||
rebuild_content = {"code":1,"data":{"gt":gt,"challenge":challenge,"validate":original_content}}
|
||||
return JSONResponse(content=rebuild_content, status_code=original_status_code)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"修改路由错误: {e}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"error": "An internal server error occurred.", "detail": str(e)}
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
from predict import predict_onnx,predict_onnx_pdl
|
||||
import uvicorn
|
||||
print(f"{' '*10}api: http://0.0.0.0:{port}/pass_nine{' '*10}")
|
||||
print(f"{' '*10}api所需参数:gt、challenge、point(可选){' '*10}")
|
||||
uvicorn.run(app,host="0.0.0.0",port=port)
|
||||
uvicorn.run(app,port=PORT)
|
||||
134
predict.py
134
predict.py
@@ -1,19 +1,14 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
|
||||
|
||||
from train import MyResNet18, data_transform
|
||||
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
|
||||
import time
|
||||
import cv2
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageDraw
|
||||
from io import BytesIO
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
def predict(icon_image, bg_image):
|
||||
from train import MyResNet18, data_transform
|
||||
import torch
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
|
||||
@@ -74,10 +69,20 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
|
||||
model_path = os.path.join(current_dir, 'model', name)
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
print("加载模型,耗时:", time.time() - start)
|
||||
print(f"加载{name}模型,耗时:{time.time() - start}")
|
||||
|
||||
def load_dfine_model(name='d-fine-n.onnx'):
|
||||
# 加载onnx模型
|
||||
global session_dfine
|
||||
start = time.time()
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, 'model', name)
|
||||
session_dfine = ort.InferenceSession(model_path)
|
||||
print(f"加载{name}模型,耗时:{time.time() - start}")
|
||||
|
||||
|
||||
def predict_onnx(icon_image, bg_image, point = None):
|
||||
import cv2
|
||||
coordinates = [
|
||||
[1, 1],
|
||||
[1, 2],
|
||||
@@ -145,7 +150,7 @@ def predict_onnx(icon_image, bg_image, point = None):
|
||||
else:
|
||||
answer = [one[0] for one in sorted_arr if one[1] > point]
|
||||
print(f"识别完成{answer},耗时: {time.time() - start}")
|
||||
draw_points_on_image(bg_image, answer)
|
||||
#draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
def predict_onnx_pdl(images_path):
|
||||
@@ -194,15 +199,110 @@ def predict_onnx_pdl(images_path):
|
||||
result = [np.argmax(one) for one in outputs]
|
||||
target = result[-1]
|
||||
answer = [coordinates[index] for index in range(9) if result[index] == target]
|
||||
if len(answer) == 0:
|
||||
all_sort =[np.argsort(one) for one in outputs]
|
||||
answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
|
||||
print(f"识别完成{answer},耗时: {time.time() - start}")
|
||||
if os.path.exists(os.path.join(images_path,"nine.jpg")):
|
||||
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
|
||||
bg_image = f.read()
|
||||
draw_points_on_image(bg_image, answer)
|
||||
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
|
||||
bg_image = f.read()
|
||||
draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
|
||||
def predict_onnx_dfine(image,draw_result=False):
|
||||
input_nodes = session_dfine.get_inputs()
|
||||
output_nodes = session_dfine.get_outputs()
|
||||
image_input_name = input_nodes[0].name
|
||||
size_input_name = input_nodes[1].name
|
||||
output_names = [node.name for node in output_nodes]
|
||||
if isinstance(image,bytes):
|
||||
im_pil = bytes_to_pil(image)
|
||||
else:
|
||||
im_pil = Image.open(image_path).convert("RGB")
|
||||
w, h = im_pil.size
|
||||
orig_size_np = np.array([[w, h]], dtype=np.int64)
|
||||
im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
|
||||
im_data = np.array(im_resized, dtype=np.float32) / 255.0
|
||||
im_data = im_data.transpose(2, 0, 1)
|
||||
im_data = np.expand_dims(im_data, axis=0)
|
||||
inputs = {
|
||||
image_input_name: im_data,
|
||||
size_input_name: orig_size_np
|
||||
}
|
||||
outputs = session_dfine.run(output_names, inputs)
|
||||
output_map = {name: data for name, data in zip(output_names, outputs)}
|
||||
labels = output_map['labels'][0]
|
||||
boxes = output_map['boxes'][0]
|
||||
scores = output_map['scores'][0]
|
||||
|
||||
colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
|
||||
mask = scores > 0.4
|
||||
filtered_labels = labels[mask]
|
||||
filtered_boxes = boxes[mask]
|
||||
filtered_scores = scores[mask]
|
||||
|
||||
rebuild_color = {}
|
||||
unique_labels = list(set(filtered_labels))
|
||||
for i, l_val in enumerate(unique_labels):
|
||||
class_id = int(l_val)
|
||||
if class_id not in rebuild_color:
|
||||
rebuild_color[class_id] = colors[i % len(colors)]
|
||||
|
||||
result = {k: [] for k in unique_labels}
|
||||
for i, box in enumerate(filtered_boxes):
|
||||
label_val = filtered_labels[i]
|
||||
class_id = int(label_val)
|
||||
color = rebuild_color[class_id]
|
||||
score = filtered_scores[i]
|
||||
|
||||
result[class_id].append({
|
||||
'box': box,
|
||||
'label_val': label_val,
|
||||
'score': score
|
||||
})
|
||||
for class_id in result:
|
||||
result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
|
||||
sorted_result = {}
|
||||
sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
|
||||
for class_id in sorted_class_ids:
|
||||
sorted_result[class_id] = result[class_id]
|
||||
points = []
|
||||
if draw_result:
|
||||
draw = ImageDraw.Draw(im_pil)
|
||||
for c1,class_id in enumerate(sorted_result):
|
||||
items = sorted_result[class_id]
|
||||
last_item = items[-1]
|
||||
center_x = (last_item['box'][0] + last_item['box'][2]) / 2
|
||||
center_y = (last_item['box'][1] + last_item['box'][3]) / 2
|
||||
text_position_center = (center_x , center_y)
|
||||
points.append(text_position_center)
|
||||
if draw_result:
|
||||
color = rebuild_color[class_id]
|
||||
draw.point((center_x, center_y), fill=color)
|
||||
text_center = f"{c1}"
|
||||
draw.text(text_position_center, text_center, fill=color)
|
||||
for c2,item in enumerate(items):
|
||||
box = item['box']
|
||||
score = item['score']
|
||||
|
||||
draw.rectangle(list(box), outline=color, width=1)
|
||||
text = f"{class_id}_{c1}-{c2}: {score:.2f}"
|
||||
text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
|
||||
draw.text(text_position, text, fill=color)
|
||||
if draw_result:
|
||||
save_path = os.path.join(validate_path,"icon_result.jpg")
|
||||
im_pil.save(save_path)
|
||||
print(f"图片可视化结果保存在{save_path}")
|
||||
print(f"图片顺序的中心点{points}")
|
||||
return points
|
||||
|
||||
|
||||
|
||||
print(f"使用推理设备: {ort.get_device()}")
|
||||
if int(os.environ.get("use_pdl",1)):
|
||||
load_model()
|
||||
if int(os.environ.get("use_dfine",1)):
|
||||
load_dfine_model()
|
||||
if __name__ == "__main__":
|
||||
# 使用resnet18.onnx
|
||||
# load_model("resnet18.onnx")
|
||||
@@ -218,7 +318,5 @@ if __name__ == "__main__":
|
||||
# predict_onnx(icon_image, bg_image)
|
||||
|
||||
# 使用PP-HGNetV2-B4.onnx
|
||||
load_model()
|
||||
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
|
||||
else:
|
||||
load_model()
|
||||
#predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
|
||||
predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)
|
||||
|
||||
Reference in New Issue
Block a user