结构微调

1.将host改为0.0.0.0,以便监听所有IP
2.修改https://github.com/Womsxd/MihoyoBBSTools/issues/198#issuecomment-2714437516等提出的时间可能为负数的情况
3.将使用resnet的推理代码整合,以精简docker镜像
This commit is contained in:
luguoyixiazi
2025-03-21 04:45:27 +08:00
committed by GitHub
parent aeaab3277d
commit 513f9dd247
2 changed files with 22 additions and 17 deletions

26
main.py
View File

@@ -18,7 +18,8 @@ app = FastAPI()
def get_pic(gt: str = Query(...),
challenge: str = Query(...),
point: str = Query(default=None),
use_v3_model = Query(default=True)
use_v3_model = Query(default=True),
save_result = Query(default=False)
):
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
t = time.time()
@@ -47,18 +48,19 @@ def get_pic(gt: str = Query(...),
result_list = predict_onnx(icon_image, bg_image, point)
point_list = [f"{col}_{row}" for row, col in result_list]
wait_time = 4.0 - (time.time() - t)
wait_time = max(0,4.0 - (time.time() - t))
time.sleep(wait_time)
result = json.loads(crack.verify(point_list))
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if 'validate' in result['data']:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
if save_result:
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if 'validate' in result['data']:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
total_time = time.time() - t
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
return JSONResponse(content=result)
@@ -68,6 +70,6 @@ def get_pic(gt: str = Query(...),
if __name__ == "__main__":
from predict import predict_onnx,predict_onnx_pdl
import uvicorn
print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
print(f"{' '*10}api: http://0.0.0.0:{port}/pass_nine{' '*10}")
print(f"{' '*10}api所需参数gt、challenge、point(可选){' '*10}")
uvicorn.run(app,host="0.0.0.0",port=port)