Compare commits

..

2 Commits

Author SHA1 Message Date
luguoyixiazi
f3d2951f80 优化推理,模型更新,镜像重新打包 2025-07-22 02:01:50 +08:00
luguoyixiazi
64887c6386 优化推理,模型更新,镜像重新打包 2025-07-22 01:37:39 +08:00

View File

@@ -208,6 +208,37 @@ def predict_onnx_pdl(images_path):
draw_points_on_image(bg_image, answer)
return answer
def calculate_iou(boxA, boxB):
xA = np.maximum(boxA[0], boxB[0])
yA = np.maximum(boxA[1], boxB[1])
xB = np.minimum(boxA[2], boxB[2])
yB = np.minimum(boxA[3], boxB[3])
intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
union_area = float(boxA_area + boxB_area - intersection_area)
if union_area == 0:
return 0.0
iou = intersection_area / union_area
return iou
def non_maximum_suppression(detections, iou_threshold=0.35):
if not detections:
return []
detections.sort(key=lambda x: x['score'], reverse=True)
final_detections = []
while detections:
best_detection = detections.pop(0)
final_detections.append(best_detection)
detections_to_keep = []
for det in detections:
iou = calculate_iou(best_detection['box'], det['box'])
if iou < iou_threshold:
detections_to_keep.append(det)
detections = detections_to_keep
return final_detections
def predict_onnx_dfine(image,draw_result=False):
input_nodes = session_dfine.get_inputs()
@@ -249,6 +280,8 @@ def predict_onnx_dfine(image,draw_result=False):
rebuild_color[class_id] = colors[i % len(colors)]
result = {k: [] for k in unique_labels}
for i, box in enumerate(filtered_boxes):
if box[2]>160 and box[3] < 45:
continue
label_val = filtered_labels[i]
class_id = int(label_val)
color = rebuild_color[class_id]
@@ -260,14 +293,22 @@ def predict_onnx_dfine(image,draw_result=False):
'score': score
})
keep_result = {}
result_points = []
for class_id in result:
tp = result[class_id]
if len(tp) < 2:
continue
tp.sort(key=lambda item: item['score'], reverse=True)
if tp[0]["score"]+tp[1]["score"] < 1.0:
continue
keep_result.update({class_id:tp[0:2]})
tp = non_maximum_suppression(result[class_id],0.01)
if len(tp) < 2:
continue
point = tp[0]["score"]+tp[1]["score"]
if point < 0.85:
continue
keep_result.update({class_id:tp[0:2]})
result_points.append({"id":class_id,"point":point})
result_points.sort(key=lambda item: item['point'], reverse=True)
if len(keep_result) > 3:
tp = {}
for one in result_points[0:3]:
tp.update({one['id']:keep_result[one['id']]})
keep_result = tp
for class_id in keep_result:
keep_result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
sorted_result = {}
@@ -305,7 +346,6 @@ def predict_onnx_dfine(image,draw_result=False):
print(f"图片顺序的中心点{points}")
return points
print(f"使用推理设备: {ort.get_device()}")
if int(os.environ.get("use_pdl",1)):