mirror of
https://github.com/luguoyixiazi/test_nine.git
synced 2025-12-05 14:42:49 +08:00
结构微调
1.将host改为0.0.0.0,以便监听所有IP 2.修改https://github.com/Womsxd/MihoyoBBSTools/issues/198#issuecomment-2714437516等提出的时间可能为负数的情况 3.将使用resnet的推理代码整合,以精简docker镜像
This commit is contained in:
26
main.py
26
main.py
@@ -18,7 +18,8 @@ app = FastAPI()
|
||||
def get_pic(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
point: str = Query(default=None),
|
||||
use_v3_model = Query(default=True)
|
||||
use_v3_model = Query(default=True),
|
||||
save_result = Query(default=False)
|
||||
):
|
||||
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
|
||||
t = time.time()
|
||||
@@ -47,18 +48,19 @@ def get_pic(gt: str = Query(...),
|
||||
result_list = predict_onnx(icon_image, bg_image, point)
|
||||
|
||||
point_list = [f"{col}_{row}" for row, col in result_list]
|
||||
wait_time = 4.0 - (time.time() - t)
|
||||
wait_time = max(0,4.0 - (time.time() - t))
|
||||
time.sleep(wait_time)
|
||||
result = json.loads(crack.verify(point_list))
|
||||
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
|
||||
if 'validate' in result['data']:
|
||||
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
|
||||
else:
|
||||
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
|
||||
os.makedirs(path_2_save,exist_ok=True)
|
||||
for pic in os.listdir(validate_path):
|
||||
if pic.startswith('cropped'):
|
||||
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
|
||||
if save_result:
|
||||
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
|
||||
if 'validate' in result['data']:
|
||||
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
|
||||
else:
|
||||
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
|
||||
os.makedirs(path_2_save,exist_ok=True)
|
||||
for pic in os.listdir(validate_path):
|
||||
if pic.startswith('cropped'):
|
||||
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
|
||||
total_time = time.time() - t
|
||||
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
|
||||
return JSONResponse(content=result)
|
||||
@@ -68,6 +70,6 @@ def get_pic(gt: str = Query(...),
|
||||
if __name__ == "__main__":
|
||||
from predict import predict_onnx,predict_onnx_pdl
|
||||
import uvicorn
|
||||
print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
|
||||
print(f"{' '*10}api: http://0.0.0.0:{port}/pass_nine{' '*10}")
|
||||
print(f"{' '*10}api所需参数:gt、challenge、point(可选){' '*10}")
|
||||
uvicorn.run(app,host="0.0.0.0",port=port)
|
||||
|
||||
13
predict.py
13
predict.py
@@ -2,9 +2,9 @@ import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from train import MyResNet18, data_transform
|
||||
|
||||
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
|
||||
import torch
|
||||
|
||||
import time
|
||||
import cv2
|
||||
from PIL import Image
|
||||
@@ -13,6 +13,8 @@ import onnxruntime as ort
|
||||
|
||||
|
||||
def predict(icon_image, bg_image):
|
||||
from train import MyResNet18, data_transform
|
||||
import torch
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
|
||||
coordinates = [
|
||||
@@ -193,9 +195,10 @@ def predict_onnx_pdl(images_path):
|
||||
target = result[-1]
|
||||
answer = [coordinates[index] for index in range(9) if result[index] == target]
|
||||
print(f"识别完成{answer},耗时: {time.time() - start}")
|
||||
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
|
||||
bg_image = f.read()
|
||||
draw_points_on_image(bg_image, answer)
|
||||
if os.path.exists(os.path.join(images_path,"nine.jpg")):
|
||||
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
|
||||
bg_image = f.read()
|
||||
draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user