mirror of
https://github.com/luguoyixiazi/test_nine.git
synced 2025-12-06 14:52:49 +08:00
Compare commits
2 Commits
dc788adc05
...
f3d2951f80
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3d2951f80 | ||
|
|
64887c6386 |
56
predict.py
56
predict.py
@@ -208,6 +208,37 @@ def predict_onnx_pdl(images_path):
|
||||
draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
def calculate_iou(boxA, boxB):
|
||||
xA = np.maximum(boxA[0], boxB[0])
|
||||
yA = np.maximum(boxA[1], boxB[1])
|
||||
xB = np.minimum(boxA[2], boxB[2])
|
||||
yB = np.minimum(boxA[3], boxB[3])
|
||||
|
||||
intersection_area = np.maximum(0, xB - xA) * np.maximum(0, yB - yA)
|
||||
boxA_area = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
|
||||
boxB_area = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
|
||||
union_area = float(boxA_area + boxB_area - intersection_area)
|
||||
if union_area == 0:
|
||||
return 0.0
|
||||
iou = intersection_area / union_area
|
||||
return iou
|
||||
def non_maximum_suppression(detections, iou_threshold=0.35):
|
||||
if not detections:
|
||||
return []
|
||||
detections.sort(key=lambda x: x['score'], reverse=True)
|
||||
|
||||
final_detections = []
|
||||
while detections:
|
||||
best_detection = detections.pop(0)
|
||||
final_detections.append(best_detection)
|
||||
detections_to_keep = []
|
||||
for det in detections:
|
||||
iou = calculate_iou(best_detection['box'], det['box'])
|
||||
if iou < iou_threshold:
|
||||
detections_to_keep.append(det)
|
||||
detections = detections_to_keep
|
||||
|
||||
return final_detections
|
||||
|
||||
def predict_onnx_dfine(image,draw_result=False):
|
||||
input_nodes = session_dfine.get_inputs()
|
||||
@@ -249,6 +280,8 @@ def predict_onnx_dfine(image,draw_result=False):
|
||||
rebuild_color[class_id] = colors[i % len(colors)]
|
||||
result = {k: [] for k in unique_labels}
|
||||
for i, box in enumerate(filtered_boxes):
|
||||
if box[2]>160 and box[3] < 45:
|
||||
continue
|
||||
label_val = filtered_labels[i]
|
||||
class_id = int(label_val)
|
||||
color = rebuild_color[class_id]
|
||||
@@ -260,14 +293,22 @@ def predict_onnx_dfine(image,draw_result=False):
|
||||
'score': score
|
||||
})
|
||||
keep_result = {}
|
||||
result_points = []
|
||||
for class_id in result:
|
||||
tp = result[class_id]
|
||||
if len(tp) < 2:
|
||||
continue
|
||||
tp.sort(key=lambda item: item['score'], reverse=True)
|
||||
if tp[0]["score"]+tp[1]["score"] < 1.0:
|
||||
continue
|
||||
keep_result.update({class_id:tp[0:2]})
|
||||
tp = non_maximum_suppression(result[class_id],0.01)
|
||||
if len(tp) < 2:
|
||||
continue
|
||||
point = tp[0]["score"]+tp[1]["score"]
|
||||
if point < 0.85:
|
||||
continue
|
||||
keep_result.update({class_id:tp[0:2]})
|
||||
result_points.append({"id":class_id,"point":point})
|
||||
result_points.sort(key=lambda item: item['point'], reverse=True)
|
||||
if len(keep_result) > 3:
|
||||
tp = {}
|
||||
for one in result_points[0:3]:
|
||||
tp.update({one['id']:keep_result[one['id']]})
|
||||
keep_result = tp
|
||||
for class_id in keep_result:
|
||||
keep_result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
|
||||
sorted_result = {}
|
||||
@@ -305,7 +346,6 @@ def predict_onnx_dfine(image,draw_result=False):
|
||||
print(f"图片顺序的中心点{points}")
|
||||
return points
|
||||
|
||||
|
||||
|
||||
print(f"使用推理设备: {ort.get_device()}")
|
||||
if int(os.environ.get("use_pdl",1)):
|
||||
|
||||
Reference in New Issue
Block a user