Files
test_nine/predict.py
2024-11-02 02:22:29 +08:00

221 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
from train import MyResNet18, data_transform
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
import torch
import time
import cv2
from PIL import Image
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
coordinates = [
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3],
]
target_images = []
target_images.append(data_transform(Image.open(BytesIO(icon_image))))
bg_images = crop_image(bg_image, coordinates)
for bg_image in bg_images:
target_images.append(data_transform(bg_image))
start = time.time()
model = MyResNet18(num_classes=91) # 这里的类别数要与训练时一致
model.load_state_dict(torch.load(model_path))
model.eval()
print("加载模型,耗时:", time.time() - start)
start = time.time()
target_images = torch.stack(target_images, dim=0)
target_outputs = model(target_images)
scores = []
for i, out_put in enumerate(target_outputs):
if i == 0:
# 增加维度,以便于计算
target_output = out_put.unsqueeze(0)
else:
similarity = torch.nn.functional.cosine_similarity(
target_output, out_put.unsqueeze(0)
)
scores.append(similarity.cpu().item())
# 从左到右,从上到下,依次为每张图片的置信度
print(scores)
# 对数组进行排序,保持下标
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# 提取最大三个数及其下标
largest_three = sorted_arr[:3]
print(largest_three)
print("识别完成,耗时:", time.time() - start)
def load_model(name='PP-HGNetV2-B4.onnx'):
# 加载onnx模型
global session,input_name
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
def predict_onnx(icon_image, bg_image, point = None):
coordinates = [
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3],
]
def cosine_similarity(vec1, vec2):
# 将输入转换为 NumPy 数组
vec1 = np.array(vec1)
vec2 = np.array(vec2)
# 计算点积
dot_product = np.dot(vec1, vec2)
# 计算向量的范数
norm_vec1 = np.linalg.norm(vec1)
norm_vec2 = np.linalg.norm(vec2)
# 计算余弦相似度
similarity = dot_product / (norm_vec1 * norm_vec2)
return similarity
def data_transforms(image):
image = image.resize((224, 224))
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
image_array = np.array(image)
image_array = image_array.astype(np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image_array = (image_array - mean) / std
image_array = np.transpose(image_array, (2, 0, 1))
# image_array = np.expand_dims(image_array, axis=0)
return image_array
target_images = []
target_images.append(data_transforms(Image.open(BytesIO(icon_image))))
bg_images = crop_image(bg_image, coordinates)
for one in bg_images:
target_images.append(data_transforms(one))
start = time.time()
outputs = session.run(None, {input_name: target_images})[0]
scores = []
for i, out_put in enumerate(outputs):
if i == 0:
target_output = out_put
else:
similarity = cosine_similarity(target_output, out_put)
scores.append(similarity)
# 从左到右,从上到下,依次为每张图片的置信度
# print(scores)
# 对数组进行排序,保持下标
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# 提取最大三个数及其下标
if point == None:
largest_three = sorted_arr[:3]
answer = [coordinates[i[0]] for i in largest_three]
# 基于分数判断
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"识别完成{answer},耗时: {time.time() - start}")
draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_pdl(images_path):
coordinates = [
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3],
]
def data_transforms(path):
# 打开图片
img = Image.open(path)
# 调整图片大小为232x224假设最短边长度调整为232像素
if img.width < img.height:
new_size = (232, int(232 * img.height / img.width))
else:
new_size = (int(232 * img.width / img.height), 232)
resized_img = img.resize(new_size, Image.BICUBIC)
# 裁剪图片为224x224
cropped_img = resized_img.crop((0, 0, 224, 224))
# 将图像转换为NumPy数组并进行归一化处理
img_array = np.array(cropped_img).astype(np.float32)
img_array /= 255.0
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img_array -= np.array(mean)
img_array /= np.array(std)
# 将通道维度移到前面
img_array = np.transpose(img_array, (2, 0, 1))
return img_array
images = []
for pic in sorted(os.listdir(images_path)):
if "cropped" not in pic:
continue
image_path = os.path.join(images_path,pic)
images.append(data_transforms(image_path))
if len(images) == 0:
raise FileNotFoundError(f"先使用切图代码切图至{image_path}再推理,图片命名如cropped_9.jpg,从0到9共十个,最后一个是检测目标")
start = time.time()
outputs = session.run(None, {input_name: images})[0]
result = [np.argmax(one) for one in outputs]
target = result[-1]
answer = [coordinates[index] for index in range(9) if result[index] == target]
print(f"识别完成{answer},耗时: {time.time() - start}")
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
draw_points_on_image(bg_image, answer)
return answer
if __name__ == "__main__":
# 使用resnet18.onnx
# load_model("resnet18.onnx")
# icon_path = "img_2_val/cropped_9.jpg"
# bg_path = "img_2_val/nine.jpg"
# with open(icon_path, "rb") as rb:
# if icon_path.endswith('.png'):
# icon_image = convert_png_to_jpg(rb.read())
# else:
# icon_image = rb.read()
# with open(bg_path, "rb") as rb:
# bg_image = rb.read()
# predict_onnx(icon_image, bg_image)
# 使用PP-HGNetV2-B4.onnx
load_model()
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
else:
load_model()