优化 D-FINE 推理正确率

This commit is contained in:
luguoyixiazi
2025-07-05 01:27:41 +08:00
committed by GitHub
parent 645f5fbb1d
commit dc788adc05

View File

@@ -218,7 +218,7 @@ def predict_onnx_dfine(image,draw_result=False):
if isinstance(image,bytes):
im_pil = bytes_to_pil(image)
else:
im_pil = Image.open(image_path).convert("RGB")
im_pil = Image.open(image).convert("RGB")
w, h = im_pil.size
orig_size_np = np.array([[w, h]], dtype=np.int64)
im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
@@ -247,7 +247,6 @@ def predict_onnx_dfine(image,draw_result=False):
class_id = int(l_val)
if class_id not in rebuild_color:
rebuild_color[class_id] = colors[i % len(colors)]
result = {k: [] for k in unique_labels}
for i, box in enumerate(filtered_boxes):
label_val = filtered_labels[i]
@@ -260,13 +259,23 @@ def predict_onnx_dfine(image,draw_result=False):
'label_val': label_val,
'score': score
})
keep_result = {}
for class_id in result:
result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
tp = result[class_id]
if len(tp) < 2:
continue
tp.sort(key=lambda item: item['score'], reverse=True)
if tp[0]["score"]+tp[1]["score"] < 1.0:
continue
keep_result.update({class_id:tp[0:2]})
for class_id in keep_result:
keep_result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
sorted_result = {}
sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
sorted_class_ids = sorted(keep_result.keys(), key=lambda cid: keep_result[cid][0]['box'][0])
for class_id in sorted_class_ids:
sorted_result[class_id] = result[class_id]
sorted_result[class_id] = keep_result[class_id]
points = []
if draw_result:
draw = ImageDraw.Draw(im_pil)
for c1,class_id in enumerate(sorted_result):