Mirror of https://github.com/luguoyixiazi/test_nine.git, synced 2025-12-06 14:52:49 +08:00
Add paddle model
predict.py (98 changed lines)
@@ -64,14 +64,15 @@ def predict(icon_image, bg_image):
     print(largest_three)
     print("识别完成,耗时:", time.time() - start)
 
-
-# Load the onnx model
-start = time.time()
-current_dir = os.path.dirname(os.path.abspath(__file__))
-model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
-session = ort.InferenceSession(model_path)
-input_name = session.get_inputs()[0].name
-print("加载模型,耗时:", time.time() - start)
+def load_model(name='PP-HGNetV2-B4.onnx'):
+    # Load the onnx model
+    global session,input_name
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session = ort.InferenceSession(model_path)
+    input_name = session.get_inputs()[0].name
+    print("加载模型,耗时:", time.time() - start)
 
 
 def predict_onnx(icon_image, bg_image, point = None):
@@ -145,15 +146,76 @@ def predict_onnx(icon_image, bg_image, point = None):
     draw_points_on_image(bg_image, answer)
     return answer
 
 
+def predict_onnx_pdl(images_path):
+    coordinates = [
+        [1, 1],
+        [1, 2],
+        [1, 3],
+        [2, 1],
+        [2, 2],
+        [2, 3],
+        [3, 1],
+        [3, 2],
+        [3, 3],
+    ]
+    def data_transforms(path):
+        # Open the image
+        img = Image.open(path)
+        # Resize the image so that the shorter side becomes 232 pixels
+        if img.width < img.height:
+            new_size = (232, int(232 * img.height / img.width))
+        else:
+            new_size = (int(232 * img.width / img.height), 232)
+        resized_img = img.resize(new_size, Image.BICUBIC)
+        # Crop the image to 224x224
+        cropped_img = resized_img.crop((0, 0, 224, 224))
+        # Convert the image to a NumPy array and normalize it
+        img_array = np.array(cropped_img).astype(np.float32)
+        img_array /= 255.0
+        mean = [0.485, 0.456, 0.406]
+        std = [0.229, 0.224, 0.225]
+        img_array -= np.array(mean)
+        img_array /= np.array(std)
+        # Move the channel dimension to the front (HWC -> CHW)
+        img_array = np.transpose(img_array, (2, 0, 1))
+        return img_array
+    images = []
+    for pic in sorted(os.listdir(images_path)):
+        if "cropped" not in pic:
+            continue
+        image_path = os.path.join(images_path, pic)
+        images.append(data_transforms(image_path))
+    if len(images) == 0:
+        raise FileNotFoundError(f"先使用切图代码切图至{image_path}再推理,图片命名如cropped_9.jpg,从0到9共十个,最后一个是检测目标")
+    start = time.time()
+    outputs = session.run(None, {input_name: images})[0]
+    result = [np.argmax(one) for one in outputs]
+    target = result[-1]
+    answer = [coordinates[index] for index in range(9) if result[index] == target]
+    print(f"识别完成{answer},耗时: {time.time() - start}")
+    with open(os.path.join(images_path, "nine.jpg"), 'rb') as f:
+        bg_image = f.read()
+    draw_points_on_image(bg_image, answer)
+    return answer
+
+
 if __name__ == "__main__":
-    icon_path = "img_2_val/cropped_9.jpg"
-    bg_path = "img_2_val/nine.jpg"
-    with open(icon_path, "rb") as rb:
-        if icon_path.endswith('.png'):
-            icon_image = convert_png_to_jpg(rb.read())
-        else:
-            icon_image = rb.read()
-    with open(bg_path, "rb") as rb:
-        bg_image = rb.read()
-    predict_onnx(icon_image, bg_image)
+    # Use resnet18.onnx
+    # load_model("resnet18.onnx")
+    # icon_path = "img_2_val/cropped_9.jpg"
+    # bg_path = "img_2_val/nine.jpg"
+    # with open(icon_path, "rb") as rb:
+    #     if icon_path.endswith('.png'):
+    #         icon_image = convert_png_to_jpg(rb.read())
+    #     else:
+    #         icon_image = rb.read()
+    # with open(bg_path, "rb") as rb:
+    #     bg_image = rb.read()
+    # predict_onnx(icon_image, bg_image)
+
+    # Use PP-HGNetV2-B4.onnx
+    load_model()
+    predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
+else:
+    load_model()
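
A minimal usage sketch of the new paddle path from a caller's perspective (an assumption, not part of the commit): it presumes predict.py is importable as-is, that model/PP-HGNetV2-B4.onnx sits next to it, and that the nine-grid captcha has already been cut into cropped_0.jpg ... cropped_9.jpg (the last crop is the detection target) plus nine.jpg in one folder; the folder name below is a placeholder.

# Usage sketch (assumption): drive the new paddle path from another script.
import predict

# Importing predict already triggers load_model() via the module-level else branch,
# loading model/PP-HGNetV2-B4.onnx; call load_model("resnet18.onnx") only to switch
# back to the old classifier.
cells = predict.predict_onnx_pdl("img_saved/img_fail/example_captcha")  # placeholder folder
print(cells)  # grid coordinates [row, col] in 1..3 of the tiles matching the target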