# test_nine/predict.py
import os
import numpy as np
from crop_image import crop_image, convert_png_to_jpg, draw_points_on_image, bytes_to_pil, validate_path
import time
from PIL import Image, ImageDraw
from io import BytesIO
import onnxruntime as ort
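
# Overview: predict() ranks the nine background tiles against the icon with
# the PyTorch ResNet18 from train.py; predict_onnx() and predict_onnx_pdl()
# do the same with ONNX classifiers; predict_onnx_dfine() runs the D-FINE
# ONNX detector to locate paired icons and return click points in order.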

def predict(icon_image, bg_image):
    """Rank the nine background tiles against the icon with the PyTorch ResNet18."""
    import torch
    from train import MyResNet18, data_transform

    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
    coordinates = [
        [1, 1], [1, 2], [1, 3],
        [2, 1], [2, 2], [2, 3],
        [3, 1], [3, 2], [3, 3],
    ]
    target_images = []
    target_images.append(data_transform(Image.open(BytesIO(icon_image))))
    bg_images = crop_image(bg_image, coordinates)
    for bg_tile in bg_images:
        target_images.append(data_transform(bg_tile))
    start = time.time()
    model = MyResNet18(num_classes=91)  # The class count must match training
    model.load_state_dict(torch.load(model_path))
    model.eval()
    print("Model loaded, elapsed:", time.time() - start)
    start = time.time()
    target_images = torch.stack(target_images, dim=0)
    with torch.no_grad():
        target_outputs = model(target_images)
    scores = []
    for i, output in enumerate(target_outputs):
        if i == 0:
            # Keep the icon embedding; add a batch dimension for the comparison
            target_output = output.unsqueeze(0)
        else:
            similarity = torch.nn.functional.cosine_similarity(
                target_output, output.unsqueeze(0)
            )
            scores.append(similarity.cpu().item())
    # Confidence for each tile, left to right, top to bottom
    print(scores)
    # Sort while keeping the original indices
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    # Take the top three scores and their indices
    largest_three = sorted_arr[:3]
    print(largest_three)
    print("Recognition finished, elapsed:", time.time() - start)

def load_model(name='PP-HGNetV2-B4.onnx'):
    # Load the ONNX classification model into module-level globals
    global session, input_name
    start = time.time()
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', name)
    session = ort.InferenceSession(model_path)
    input_name = session.get_inputs()[0].name
    print(f"Loaded model {name}, elapsed: {time.time() - start}")

def load_dfine_model(name='d-fine-n.onnx'):
    # Load the ONNX detection model into a module-level global
    global session_dfine
    start = time.time()
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', name)
    session_dfine = ort.InferenceSession(model_path)
    print(f"Loaded model {name}, elapsed: {time.time() - start}")

def predict_onnx(icon_image, bg_image, point=None):
    """Rank the nine tiles with the ONNX classifier loaded by load_model()."""
    import cv2

    coordinates = [
        [1, 1], [1, 2], [1, 3],
        [2, 1], [2, 2], [2, 3],
        [3, 1], [3, 2], [3, 3],
    ]

    def cosine_similarity(vec1, vec2):
        # Convert inputs to NumPy arrays
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        # Dot product and vector norms
        dot_product = np.dot(vec1, vec2)
        norm_vec1 = np.linalg.norm(vec1)
        norm_vec2 = np.linalg.norm(vec2)
        # Cosine similarity
        return dot_product / (norm_vec1 * norm_vec2)

    def data_transforms(image):
        image = image.resize((224, 224))
        # Drop the alpha channel (inputs are expected to be RGBA)
        image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
        image_array = np.array(image)
        image_array = image_array.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image_array = (image_array - mean) / std
        image_array = np.transpose(image_array, (2, 0, 1))
        # image_array = np.expand_dims(image_array, axis=0)
        return image_array

    target_images = []
    target_images.append(data_transforms(Image.open(BytesIO(icon_image))))
    bg_images = crop_image(bg_image, coordinates)
    for one in bg_images:
        target_images.append(data_transforms(one))
    start = time.time()
    # Stack the list into one (N, C, H, W) batch array for onnxruntime
    outputs = session.run(None, {input_name: np.stack(target_images)})[0]
    scores = []
    for i, output in enumerate(outputs):
        if i == 0:
            target_output = output
        else:
            similarity = cosine_similarity(target_output, output)
            scores.append(similarity)
    # Confidence for each tile, left to right, top to bottom
    # print(scores)
    # Sort while keeping the original indices
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    if point is None:
        # Take the top three tiles and map them back to grid coordinates
        largest_three = sorted_arr[:3]
        answer = [coordinates[i[0]] for i in largest_three]
    else:
        # Threshold on the score instead; note this branch returns raw tile
        # indices rather than grid coordinates
        answer = [one[0] for one in sorted_arr if one[1] > point]
    print(f"Recognition finished {answer}, elapsed: {time.time() - start}")
    # draw_points_on_image(bg_image, answer)
    return answer
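
# Note: with point=None the function returns up to three [row, col] grid
# coordinates; with a numeric threshold (e.g. point=0.5, an illustrative
# value) it instead returns the indices of every tile scoring above it.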

def predict_onnx_pdl(images_path):
    """Classify pre-cropped tiles with the ONNX model and match them to the target."""
    coordinates = [
        [1, 1], [1, 2], [1, 3],
        [2, 1], [2, 2], [2, 3],
        [3, 1], [3, 2], [3, 3],
    ]

    def data_transforms(path):
        # Open the image
        img = Image.open(path)
        # Resize so that the shorter side is 232 pixels
        if img.width < img.height:
            new_size = (232, int(232 * img.height / img.width))
        else:
            new_size = (int(232 * img.width / img.height), 232)
        resized_img = img.resize(new_size, Image.BICUBIC)
        # Crop the top-left 224x224 region
        cropped_img = resized_img.crop((0, 0, 224, 224))
        # Convert to a NumPy array and normalize
        img_array = np.array(cropped_img).astype(np.float32)
        img_array /= 255.0
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        img_array -= np.array(mean)
        img_array /= np.array(std)
        # Move the channel dimension to the front
        img_array = np.transpose(img_array, (2, 0, 1))
        return img_array

    images = []
    for pic in sorted(os.listdir(images_path)):
        if "cropped" not in pic:
            continue
        image_path = os.path.join(images_path, pic)
        images.append(data_transforms(image_path))
    if len(images) == 0:
        raise FileNotFoundError(
            f"Run the cropping code first to produce tiles under {images_path}; "
            "files are named like cropped_9.jpg, ten in total from 0 to 9, and "
            "the last one is the detection target"
        )
    start = time.time()
    outputs = session.run(None, {input_name: np.stack(images)})[0]
    result = [np.argmax(one) for one in outputs]
    target = result[-1]
    answer = [coordinates[index] for index in range(9) if result[index] == target]
    if len(answer) == 0:
        # Fall back to each tile's second-best class (argsort is ascending,
        # so [-2] is the second-highest logit)
        all_sort = [np.argsort(one) for one in outputs]
        answer = [coordinates[index] for index in range(9) if all_sort[index][-2] == target]
    print(f"Recognition finished {answer}, elapsed: {time.time() - start}")
    with open(os.path.join(images_path, "nine.jpg"), 'rb') as f:
        bg_image = f.read()
    draw_points_on_image(bg_image, answer)
    return answer
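
# Example (path taken from the commented-out __main__ call below):
# predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')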

def calculate_iou(boxA, boxB):
    # Intersection-over-union of two [x1, y1, x2, y2] boxes
    xA = np.maximum(boxA[0], boxB[0])
    yA = np.maximum(boxA[1], boxB[1])
    xB = np.minimum(boxA[2], boxB[2])
    yB = np.minimum(boxA[3], boxB[3])
    intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
    boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union_area = float(boxA_area + boxB_area - intersection_area)
    if union_area == 0:
        return 0.0
    iou = intersection_area / union_area
    return iou
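
# Worked example: boxes [0, 0, 10, 10] and [5, 5, 15, 15] intersect in a 5x5
# patch, so calculate_iou returns 25 / (100 + 100 - 25) ≈ 0.143.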

def non_maximum_suppression(detections, iou_threshold=0.35):
    if not detections:
        return []
    detections.sort(key=lambda x: x['score'], reverse=True)
    final_detections = []
    while detections:
        best_detection = detections.pop(0)
        final_detections.append(best_detection)
        detections_to_keep = []
        for det in detections:
            iou = calculate_iou(best_detection['box'], det['box'])
            if iou < iou_threshold:
                detections_to_keep.append(det)
        detections = detections_to_keep
    return final_detections
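
# Example: of two heavily overlapping detections, only the higher-scoring one
# survives at the default threshold (their IoU of about 0.68 exceeds 0.35):
# non_maximum_suppression([{'box': [0, 0, 10, 10], 'score': 0.9},
#                          {'box': [1, 1, 11, 11], 'score': 0.8}])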

def predict_onnx_dfine(image, draw_result=False):
    """Detect paired icons with the D-FINE model and return ordered click points."""
    input_nodes = session_dfine.get_inputs()
    output_nodes = session_dfine.get_outputs()
    image_input_name = input_nodes[0].name
    size_input_name = input_nodes[1].name
    output_names = [node.name for node in output_nodes]
    if isinstance(image, bytes):
        im_pil = bytes_to_pil(image)
    else:
        im_pil = Image.open(image).convert("RGB")
    w, h = im_pil.size
    orig_size_np = np.array([[w, h]], dtype=np.int64)
    im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
    im_data = np.array(im_resized, dtype=np.float32) / 255.0
    im_data = im_data.transpose(2, 0, 1)
    im_data = np.expand_dims(im_data, axis=0)
    inputs = {
        image_input_name: im_data,
        size_input_name: orig_size_np
    }
    outputs = session_dfine.run(output_names, inputs)
    output_map = {name: data for name, data in zip(output_names, outputs)}
    labels = output_map['labels'][0]
    boxes = output_map['boxes'][0]
    scores = output_map['scores'][0]
    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
    mask = scores > 0.4
    filtered_labels = labels[mask]
    filtered_boxes = boxes[mask]
    filtered_scores = scores[mask]
    rebuild_color = {}
    unique_labels = list(set(filtered_labels))
    for i, l_val in enumerate(unique_labels):
        class_id = int(l_val)
        if class_id not in rebuild_color:
            rebuild_color[class_id] = colors[i % len(colors)]
    result = {int(k): [] for k in unique_labels}
    for i, box in enumerate(filtered_boxes):
        # Skip boxes whose bottom-right corner falls in the upper strip
        if box[2] > 160 and box[3] < 45:
            continue
        label_val = filtered_labels[i]
        class_id = int(label_val)
        color = rebuild_color[class_id]
        score = filtered_scores[i]
        result[class_id].append({
            'box': box,
            'label_val': label_val,
            'score': score
        })
    keep_result = {}
    result_points = []
    for class_id in result:
        # Keep classes detected at least twice with a confident pair
        tp = non_maximum_suppression(result[class_id], 0.01)
        if len(tp) < 2:
            continue
        point = tp[0]["score"] + tp[1]["score"]
        if point < 0.85:
            continue
        keep_result[class_id] = tp[0:2]
        result_points.append({"id": class_id, "point": point})
    result_points.sort(key=lambda item: item['point'], reverse=True)
    if len(keep_result) > 3:
        # Keep only the three most confident classes
        tp = {}
        for one in result_points[0:3]:
            tp[one['id']] = keep_result[one['id']]
        keep_result = tp
    for class_id in keep_result:
        # Bottom-most detection first within each class
        keep_result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
    # Order classes left to right by the x of their bottom-most detection
    sorted_result = {}
    sorted_class_ids = sorted(keep_result.keys(), key=lambda cid: keep_result[cid][0]['box'][0])
    for class_id in sorted_class_ids:
        sorted_result[class_id] = keep_result[class_id]
    points = []
    if draw_result:
        draw = ImageDraw.Draw(im_pil)
    for c1, class_id in enumerate(sorted_result):
        items = sorted_result[class_id]
        # items[-1] is the top-most detection of the class (sorted bottom-most first)
        last_item = items[-1]
        center_x = (last_item['box'][0] + last_item['box'][2]) / 2
        center_y = (last_item['box'][1] + last_item['box'][3]) / 2
        text_position_center = (center_x, center_y)
        points.append(text_position_center)
        if draw_result:
            color = rebuild_color[class_id]
            draw.point((center_x, center_y), fill=color)
            text_center = f"{c1}"
            draw.text(text_position_center, text_center, fill=color)
            for c2, item in enumerate(items):
                box = item['box']
                score = item['score']
                draw.rectangle(list(box), outline=color, width=1)
                text = f"{class_id}_{c1}-{c2}: {score:.2f}"
                text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
                draw.text(text_position, text, fill=color)
    if draw_result:
        save_path = os.path.join(validate_path, "icon_result.jpg")
        im_pil.save(save_path)
        print(f"Visualized result saved to {save_path}")
    print(f"Center points in order: {points}")
    return points
print(f"使用推理设备: {ort.get_device()}")
if int(os.environ.get("use_pdl",1)):
load_model()
if int(os.environ.get("use_dfine",1)):
load_dfine_model()
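
# Both models load at import time by default; set the environment variables
# use_pdl=0 or use_dfine=0 to skip either load.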

if __name__ == "__main__":
    # Using resnet18.onnx
    # load_model("resnet18.onnx")
    # icon_path = "img_2_val/cropped_9.jpg"
    # bg_path = "img_2_val/nine.jpg"
    # with open(icon_path, "rb") as rb:
    #     if icon_path.endswith('.png'):
    #         icon_image = convert_png_to_jpg(rb.read())
    #     else:
    #         icon_image = rb.read()
    # with open(bg_path, "rb") as rb:
    #     bg_image = rb.read()
    # predict_onnx(icon_image, bg_image)
    # Using PP-HGNetV2-B4.onnx
    # predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
    predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg", True)