diff --git a/PP-HGNetV2-B4.yaml b/PP-HGNetV2-B4.yaml
new file mode 100644
index 0000000..6199fcd
--- /dev/null
+++ b/PP-HGNetV2-B4.yaml
@@ -0,0 +1,42 @@
+Global:
+  model: PP-HGNetV2-B4
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "dataset"
+  device: gpu:0
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 91
+  epochs_iters: 100
+  batch_size: 64
+  learning_rate: 0.05
+  pretrain_weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
+  warmup_steps: 5
+  resume_path: null
+  log_interval: 1
+  eval_interval: 1
+  save_interval: 5
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
+  # weight_path: "./output/best_model/inference/inference"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "cropped_3.jpg"
+  kernel_option:
+    run_mode: paddle
diff --git a/README.md b/README.md
index 12dfcdd..143d959 100644
--- a/README.md
+++ b/README.md
@@ -16,28 +16,48 @@ api:https://github.com/ravizhan/geetest-v3-click-crack
 
 ### 1. Install dependencies
 
+To train the Paddle model you also need to install PaddleX and its image classification module; see https://github.com/PaddlePaddle/PaddleX for installation instructions.
+
 ```
 pip install -r requirements.txt
 ```
 
 ### 2. Prepare your own dataset (V3 and V4 differ)
 
-- For dataset details see the project linked above, but that project uses a V4 dataset and there is no V3 demo, so improvise; accuracy with V4 data is underwhelming, maybe try other models and see whether they generalize better
+##### a. Training resnet18
 
-- To crop V3 images use crop_image_v3 in crop_image.py, for V4 use crop_image, and write your own cropping script
+- For dataset details see the project linked above, but that project uses a V4 dataset and there is no V3 demo, so improvise; training on V4 and running on V3 without code changes gives underwhelming accuracy
+- The main difference is image size. The V4 API returns two images, a target icon and a 3x3 grid, while V3 puts everything into one image from which the target has to be cropped, and the V3 target is very low resolution. A V4 grid cell is 100 * 86 after cropping (black borders removed), whereas a V3 grid cell is 112 * 112; it is unclear what transform V4 applies on top of V3, so just adjust the preprocessing accordingly
 
+##### b. Training PP-HGNetV2-B4
+
+Picked more or less at random from Paddle. The dataset format is shown below; if you train on V4 data and run against V3, adding more augmentations is recommended.
+
+```
+ dataset
+ ├─images     # all images
+ ├─label.txt  # label file; each line is <index><space><class name>, e.g. 15 地球仪
+ ├─train.txt  # training list; each line is <image path><space><class index>, e.g. images/001.jpg 0
+ └─validation and test lists follow the same format
+```
+
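+A minimal sketch of the two text files in this layout (only the example entries from the comments above are shown; real files have one line per class/image):
+
+```
+label.txt:
+15 地球仪
+
+train.txt:
+images/001.jpg 0
+```
+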
+##### c. To crop V3 images use crop_image_v3 in crop_image.py; for V4 use crop_image; write your own cropping script
 
 ### 3. Train the model
 
-Run `python train.py` to train
+- To train resnet18, run `python train.py`
+- To train PP-HGNetV2-B4, run `python train_paddle.py`
 
 ### 4. Convert the model to ONNX
 
-Run `python convert.py` (edit it to pick the checkpoint to convert, usually the one with the lowest loss)
+- Run `python convert.py` (edit it to pick the checkpoint to convert, usually the one with the lowest loss)
+- Converting the Paddle model requires paddle2onnx; see https://www.paddlepaddle.org.cn/documentation/docs/guides/advanced/model_to_onnx_cn.html for details
 
 ### 5. Start the FastAPI service
 
-Run `python main.py`
+Run `python main.py` (the Paddle ONNX model is used by default; to use resnet18 instead, change the commented-out code yourself)
+
+Because of the click-trajectory issue, recognition can be correct yet verification still fails, so increasing the number of retries is recommended
 
 ### 6. API usage
@@ -46,18 +66,14 @@ Example call in Python:
 
 ```python
 import httpx
 
-res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge},timeout=10)
-
-datas = res.json()['data']
-
-if datas['result'] == 'success':
-
-    return datas['validate']
+def game_captcha(gt: str, challenge: str):
+    res = httpx.get("http://127.0.0.1:9645/pass_nine", params={'gt': gt, 'challenge': challenge, 'use_v3_model': True}, timeout=10)
+    datas = res.json()['data']
+    if datas['result'] == 'success':
+        return datas['validate']
+    return None  # returns None on failure, the validate string on success
 ```
 
+#### --Promo--
-
-
-
-
-
+Everyone is welcome to check out and support my other projects too, meow~~~~~~~~
diff --git a/crop_image.py b/crop_image.py
index 40b38ac..4dcebf2 100644
--- a/crop_image.py
+++ b/crop_image.py
@@ -132,7 +132,7 @@ if __name__ == "__main__":
     # wb.write(icon_img_jpg)
 
     # V3测试代码
-    pic = "./img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg"
+    pic = "img_saved/7fe559a85bac4c03bc6ea7b2e85325bf.jpg"
    print("推理图片为:",pic)
    with open(pic, "rb") as f:
        img = f.read()
diff --git a/main.py b/main.py
index 4691ec3..5d75557 100644
--- a/main.py
+++ b/main.py
@@ -15,7 +15,11 @@ app = FastAPI()
 
 
 @app.get("/pass_nine")
-def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Query(default=None)):
+def get_pic(gt: str = Query(...),
+            challenge: str = Query(...),
+            point: str = Query(default=None),
+            use_v3_model: bool = Query(default=True)
+            ):
     print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
     t = time.time()
@@ -32,13 +36,16 @@ def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Quer
 
     pic_content,pic_name = crack.get_pic()
     crop_image_v3(pic_content)
-
-    with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
-        icon_image = rb.read()
-    with open(f"{validate_path}/nine.jpg", "rb") as rb:
-        bg_image = rb.read()
-    result_list = predict_onnx(icon_image, bg_image, point)
+    if use_v3_model:
+        result_list = predict_onnx_pdl(validate_path)
+    else:
+        with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
+            icon_image = rb.read()
+        with open(f"{validate_path}/nine.jpg", "rb") as rb:
+            bg_image = rb.read()
+        result_list = predict_onnx(icon_image, bg_image, point)
+
     point_list = [f"{col}_{row}" for row, col in result_list]
     wait_time = 4.0 - (time.time() - t)
     time.sleep(wait_time)
@@ -59,7 +66,7 @@
 
 
 if __name__ == "__main__":
-    from predict import predict_onnx
+    from predict import predict_onnx, predict_onnx_pdl
     import uvicorn
     print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
     print(f"{' '*10}api所需参数:gt、challenge、point(可选){' '*10}")
diff --git a/predict.py b/predict.py
index b195a87..7951e2e 100644
--- a/predict.py
+++ b/predict.py
@@ -64,14 +64,15 @@ def predict(icon_image, bg_image):
     print(largest_three)
     print("识别完成,耗时:", time.time() - start)
 
-
-# 加载onnx模型
-start = time.time()
-current_dir = os.path.dirname(os.path.abspath(__file__))
-model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
-session = ort.InferenceSession(model_path)
-input_name = session.get_inputs()[0].name
-print("加载模型,耗时:", time.time() - start)
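+# session / input_name are module-level globals set by load_model(); the module loads the
+# default PP-HGNetV2-B4.onnx at import time (see the bottom of this file). To use the
+# resnet18 pipeline instead, call load_model("resnet18.onnx") before predict_onnx().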
+def load_model(name='PP-HGNetV2-B4.onnx'):
+    # load the ONNX model
+    global session, input_name
+    start = time.time()
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    model_path = os.path.join(current_dir, 'model', name)
+    session = ort.InferenceSession(model_path)
+    input_name = session.get_inputs()[0].name
+    print("Model loaded, took:", time.time() - start)
 
 
 def predict_onnx(icon_image, bg_image, point = None):
@@ -145,15 +146,76 @@
         draw_points_on_image(bg_image, answer)
     return answer
 
+def predict_onnx_pdl(images_path):
+    # cell coordinates of the 3x3 grid, row first
+    coordinates = [
+        [1, 1],
+        [1, 2],
+        [1, 3],
+        [2, 1],
+        [2, 2],
+        [2, 3],
+        [3, 1],
+        [3, 2],
+        [3, 3],
+    ]
+    def data_transforms(path):
+        # open the image
+        img = Image.open(path)
+        # resize so that the shorter side is 232 pixels
+        if img.width < img.height:
+            new_size = (232, int(232 * img.height / img.width))
+        else:
+            new_size = (int(232 * img.width / img.height), 232)
+        resized_img = img.resize(new_size, Image.BICUBIC)
+        # crop a 224x224 patch from the top-left corner
+        cropped_img = resized_img.crop((0, 0, 224, 224))
+        # convert to a NumPy array and normalize with ImageNet mean/std
+        img_array = np.array(cropped_img).astype(np.float32)
+        img_array /= 255.0
+        mean = [0.485, 0.456, 0.406]
+        std = [0.229, 0.224, 0.225]
+        img_array -= np.array(mean)
+        img_array /= np.array(std)
+        # move the channel dimension to the front (HWC -> CHW)
+        img_array = np.transpose(img_array, (2, 0, 1))
+        return img_array
+    images = []
+    for pic in sorted(os.listdir(images_path)):
+        if "cropped" not in pic:
+            continue
+        image_path = os.path.join(images_path, pic)
+        images.append(data_transforms(image_path))
+    if len(images) == 0:
+        raise FileNotFoundError(f"No cropped images found in {images_path}; run the cropping script first (files named like cropped_9.jpg, ten of them from 0 to 9, the last one being the target icon)")
+    start = time.time()
+    outputs = session.run(None, {input_name: np.stack(images)})[0]
+    result = [np.argmax(one) for one in outputs]
+    target = result[-1]
+    answer = [coordinates[index] for index in range(9) if result[index] == target]
+    print(f"Recognition finished {answer}, took: {time.time() - start}")
+    with open(os.path.join(images_path, "nine.jpg"), 'rb') as f:
+        bg_image = f.read()
+    draw_points_on_image(bg_image, answer)
+    return answer
+
+
 if __name__ == "__main__":
-    icon_path = "img_2_val/cropped_9.jpg"
-    bg_path = "img_2_val/nine.jpg"
-    with open(icon_path, "rb") as rb:
-        if icon_path.endswith('.png'):
-            icon_image = convert_png_to_jpg(rb.read())
-        else:
-            icon_image = rb.read()
-    with open(bg_path, "rb") as rb:
-        bg_image = rb.read()
-    predict_onnx(icon_image, bg_image)
\ No newline at end of file
+    # use resnet18.onnx
+    # load_model("resnet18.onnx")
+    # icon_path = "img_2_val/cropped_9.jpg"
+    # bg_path = "img_2_val/nine.jpg"
+    # with open(icon_path, "rb") as rb:
+    #     if icon_path.endswith('.png'):
+    #         icon_image = convert_png_to_jpg(rb.read())
+    #     else:
+    #         icon_image = rb.read()
+    # with open(bg_path, "rb") as rb:
+    #     bg_image = rb.read()
+    # predict_onnx(icon_image, bg_image)
+
+    # use PP-HGNetV2-B4.onnx
+    load_model()
+    predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
+else:
+    load_model()
\ No newline at end of file
diff --git a/train_paddle.py b/train_paddle.py
new file mode 100644
index 0000000..a5c9d9e
--- /dev/null
+++ b/train_paddle.py
@@ -0,0 +1,46 @@
+import os
+
+
+from paddlex.utils.result_saver import try_except_decorator
+from paddlex.utils.config import parse_args, get_config
+from paddlex.utils.errors import raise_unsupported_api_error
+from paddlex.model import _ModelBasedConfig
+
+print(f"""Expected dataset layout:
+    dataset
+    ├─images     # all images
+    ├─label.txt  # label file; each line is <index><space><class name>, e.g. 15 地球仪
+    ├─train.txt  # training list; each line is <image path><space><class index>, e.g. images/001.jpg 0
+    └─validation and test lists follow the same format
+    """)
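+# Engine below appears to be adapted from PaddleX's own CLI entry point, with the config file
+# and mode hard-coded so that `python train_paddle.py` trains PP-HGNetV2-B4 on ./dataset directly.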
+class Engine(object):
+    """Engine"""
+
+    def __init__(self):
+        args = parse_args()
+        args.config = 'PP-HGNetV2-B4.yaml'
+        args.override = ['Global.mode=train', 'Global.dataset_dir=dataset']
+        config = get_config(args.config, overrides=args.override, show=False)
+        self._mode = config.Global.mode
+        self._output = config.Global.output
+        self._model = _ModelBasedConfig(config)
+
+    @try_except_decorator
+    def run(self):
+        """the main function"""
+        if self._mode == "check_dataset":
+            return self._model.check_dataset()
+        elif self._mode == "train":
+            self._model.train()
+        elif self._mode == "evaluate":
+            return self._model.evaluate()
+        elif self._mode == "export":
+            return self._model.export()
+        elif self._mode == "predict":
+            for res in self._model.predict():
+                res.print(json_format=False)
+        else:
+            raise_unsupported_api_error(f"{self._mode}", self.__class__)
+
+if __name__ == "__main__":
+    Engine().run()
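+
+    # Presumed follow-up (from the README): convert the trained/exported model with paddle2onnx
+    # and save it as model/PP-HGNetV2-B4.onnx, which is the default file predict.load_model() loads.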