luguoyixiazi
2025-07-02 13:14:05 +08:00
committed by GitHub
parent 657096ee2d
commit 1beae61927
4 changed files with 341 additions and 78 deletions

@@ -1,19 +1,14 @@
import os
import numpy as np
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
from train import MyResNet18, data_transform
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
import time
import cv2
from PIL import Image
from PIL import Image, ImageDraw
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
    from train import MyResNet18, data_transform
    import torch
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
@@ -74,10 +69,20 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
    model_path = os.path.join(current_dir, 'model', name)
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    print("Model loaded, elapsed:", time.time() - start)
    print(f"Loaded model {name}, elapsed: {time.time() - start}")
def load_dfine_model(name='d-fine-n.onnx'):
    # Load the ONNX model
    global session_dfine
    start = time.time()
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', name)
    session_dfine = ort.InferenceSession(model_path)
    print(f"Loaded model {name}, elapsed: {time.time() - start}")
def predict_onnx(icon_image, bg_image, point = None):
    import cv2
    coordinates = [
        [1, 1],
        [1, 2],
@@ -145,7 +150,7 @@ def predict_onnx(icon_image, bg_image, point = None):
    else:
        answer = [one[0] for one in sorted_arr if one[1] > point]
    print(f"Recognition finished {answer}, elapsed: {time.time() - start}")
    draw_points_on_image(bg_image, answer)
    #draw_points_on_image(bg_image, answer)
    return answer
def predict_onnx_pdl(images_path):
@@ -194,15 +199,110 @@ def predict_onnx_pdl(images_path):
    result = [np.argmax(one) for one in outputs]
    target = result[-1]
    answer = [coordinates[index] for index in range(9) if result[index] == target]
    if len(answer) == 0:
        all_sort = [np.argsort(one) for one in outputs]
        answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
    print(f"Recognition finished {answer}, elapsed: {time.time() - start}")
    if os.path.exists(os.path.join(images_path,"nine.jpg")):
        with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
            bg_image = f.read()
        draw_points_on_image(bg_image, answer)
    with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
        bg_image = f.read()
    draw_points_on_image(bg_image, answer)
    return answer
def predict_onnx_dfine(image,draw_result=False):
    input_nodes = session_dfine.get_inputs()
    output_nodes = session_dfine.get_outputs()
    image_input_name = input_nodes[0].name
    size_input_name = input_nodes[1].name
    output_names = [node.name for node in output_nodes]
    if isinstance(image,bytes):
        im_pil = bytes_to_pil(image)
    else:
        im_pil = Image.open(image).convert("RGB")
    w, h = im_pil.size
    orig_size_np = np.array([[w, h]], dtype=np.int64)
    # Resize to the 320x320 network input, scale pixels to [0, 1], NCHW layout
    im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
    im_data = np.array(im_resized, dtype=np.float32) / 255.0
    im_data = im_data.transpose(2, 0, 1)
    im_data = np.expand_dims(im_data, axis=0)
    inputs = {
        image_input_name: im_data,
        size_input_name: orig_size_np
    }
    outputs = session_dfine.run(output_names, inputs)
    output_map = {name: data for name, data in zip(output_names, outputs)}
    labels = output_map['labels'][0]
    boxes = output_map['boxes'][0]
    scores = output_map['scores'][0]
    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
    # Keep detections above a 0.4 confidence threshold
    mask = scores > 0.4
    filtered_labels = labels[mask]
    filtered_boxes = boxes[mask]
    filtered_scores = scores[mask]
    # Assign each detected class a fixed drawing color
    rebuild_color = {}
    unique_labels = list(set(filtered_labels))
    for i, l_val in enumerate(unique_labels):
        class_id = int(l_val)
        if class_id not in rebuild_color:
            rebuild_color[class_id] = colors[i % len(colors)]
    # Group the remaining detections by class
    result = {k: [] for k in unique_labels}
    for i, box in enumerate(filtered_boxes):
        label_val = filtered_labels[i]
        class_id = int(label_val)
        color = rebuild_color[class_id]
        score = filtered_scores[i]
        result[class_id].append({
            'box': box,
            'label_val': label_val,
            'score': score
        })
    # Within each class, sort by bottom edge (y2), lowest box first;
    # then order the classes left to right by the x1 of their lowest box
    for class_id in result:
        result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
    sorted_result = {}
    sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
    for class_id in sorted_class_ids:
        sorted_result[class_id] = result[class_id]
    points = []
    if draw_result:
        draw = ImageDraw.Draw(im_pil)
    for c1,class_id in enumerate(sorted_result):
        items = sorted_result[class_id]
        # The returned point for each class is the center of its top-most box
        last_item = items[-1]
        center_x = (last_item['box'][0] + last_item['box'][2]) / 2
        center_y = (last_item['box'][1] + last_item['box'][3]) / 2
        text_position_center = (center_x , center_y)
        points.append(text_position_center)
        if draw_result:
            color = rebuild_color[class_id]
            draw.point((center_x, center_y), fill=color)
            text_center = f"{c1}"
            draw.text(text_position_center, text_center, fill=color)
            for c2,item in enumerate(items):
                box = item['box']
                score = item['score']
                draw.rectangle(list(box), outline=color, width=1)
                text = f"{class_id}_{c1}-{c2}: {score:.2f}"
                text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
                draw.text(text_position, text, fill=color)
    if draw_result:
        save_path = os.path.join(validate_path,"icon_result.jpg")
        im_pil.save(save_path)
        print(f"Visualization result saved to {save_path}")
    print(f"Center points in image order: {points}")
    return points
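# ---------------------------------------------------------------------------
# Editor's sketch (not part of this commit): the ordering convention used by
# predict_onnx_dfine above, replayed on dummy detections. Boxes are assumed to
# be [x1, y1, x2, y2]; per class the boxes are sorted by bottom edge (y2)
# descending, the last (top-most) box supplies the returned center, and the
# classes are ordered left to right by the x1 of their bottom-most box.
def _demo_order_points(detections):
    # detections: list of (class_id, [x1, y1, x2, y2]) pairs
    groups = {}
    for class_id, box in detections:
        groups.setdefault(class_id, []).append(box)
    for class_id in groups:
        groups[class_id].sort(key=lambda b: b[3], reverse=True)
    ordered_ids = sorted(groups, key=lambda cid: groups[cid][0][0])
    points = []
    for class_id in ordered_ids:
        top = groups[class_id][-1]
        points.append(((top[0] + top[2]) / 2, (top[1] + top[3]) / 2))
    return points
# _demo_order_points([(1, [200, 250, 240, 290]), (1, [30, 40, 70, 80]),
#                     (2, [60, 250, 100, 290]), (2, [150, 10, 190, 50])])
# -> [(170.0, 30.0), (50.0, 60.0)]
# ---------------------------------------------------------------------------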
print(f"使用推理设备: {ort.get_device()}")
if int(os.environ.get("use_pdl",1)):
load_model()
if int(os.environ.get("use_dfine",1)):
load_dfine_model()
if __name__ == "__main__":
    # Use resnet18.onnx
    # load_model("resnet18.onnx")
@@ -218,7 +318,5 @@ if __name__ == "__main__":
        # predict_onnx(icon_image, bg_image)
        # Use PP-HGNetV2-B4.onnx
        load_model()
        predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
    else:
        load_model()
        #predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
        predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)