增加paddle模型

This commit is contained in:
luguoyixiazi
2024-11-02 02:22:29 +08:00
committed by GitHub
parent bfb1ace0ee
commit 0141a1ccf0
6 changed files with 217 additions and 44 deletions

42
PP-HGNetV2-B4.yaml Normal file
View File

@@ -0,0 +1,42 @@
Global:
model: PP-HGNetV2-B4
mode: check_dataset # check_dataset/train/evaluate/predict
dataset_dir: "dataset"
device: gpu:0
output: "output"
CheckDataset:
convert:
enable: False
src_dataset_type: null
split:
enable: False
train_percent: null
val_percent: null
Train:
num_classes: 91
epochs_iters: 100
batch_size: 64
learning_rate: 0.05
pretrain_weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
warmup_steps: 5
resume_path: null
log_interval: 1
eval_interval: 1
save_interval: 5
Evaluate:
weight_path: "output/best_model/best_model.pdparams"
log_interval: 1
Export:
weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
# weight_path: "./output/best_model/inference/inference"
Predict:
batch_size: 1
model_dir: "output/best_model/inference"
input: "cropped_3.jpg"
kernel_option:
run_mode: paddle

View File

@@ -16,28 +16,48 @@ apihttps://github.com/ravizhan/geetest-v3-click-crack
### 1.安装依赖
如果要训练paddle的话,还得安装paddlex及图像分类模块,安装方法参见项目:https://github.com/PaddlePaddle/PaddleX
```
pip install -r requirements.txt
```
### 2.自行准备数据集V3和V4有区别
- 数据集详情参考上面标注的项目但是上面项目是V4数据集V3没有demo自行发挥吧用V4正确率有点感人或许可以试试别的模型看看能不能泛化
##### a. 训练resnet18
- 如果要切V3的图用crop_image.py的crop_image_v3切V4则使用crop_image自行编写切图脚本
- 数据集详情参考上面标注的项目但是上面项目是V4数据集V3没有demo自行发挥吧用V4练V3不改代码正确率有点感人
- 主要是V4的尺寸和V3有差别V4的api直接给两张图一张是目标图一张是九宫格V3放在一起要切目标且V3目标图清晰度很低V4九宫格切了之后是100 * 86的图去掉黑边但是V3九宫格切的是112 * 112不确定V4九宫格内容在V3基础上做了什么变换反正改预处理就完事了
##### b. 训练PP-HGNetV2-B4
在paddle上随便找的数据集格式如下如果拿V4练V3建议是多整点变换
```
dataset
├─images #所有图片存放路径
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>如15 地球仪
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>如images/001.jpg 0
└─验证集和测试集同上
```
##### c. 如果要切V3的图用crop_image.py的crop_image_v3切V4则使用crop_image自行编写切图脚本
### 3.训练模型
训练运行 `python train.py`
- 训练resnet18运行 `python train.py`
- 如果训练PP-HGNetV2-B4运行`python train_paddle.py`
### 4.模型转换为onnx
运行 `python convert.py`自行进去修改需要转换的模型一般是选loss小的
- 运行 `python convert.py`自行进去修改需要转换的模型一般是选loss小的
- paddle模型转换要装paddle2onnx,详情参见:https://www.paddlepaddle.org.cn/documentation/docs/guides/advanced/model_to_onnx_cn.html
### 5.启动fastapi服务
运行 `python main.py`
运行 `python main.py`默认用的paddle的onnx模型如果要用resnet18可以自己改注释
由于轨迹问题,可能会出现验证正确但是结果失败的情况,所以建议增加retry次数
### 6.api调用
@@ -46,18 +66,14 @@ python调用如
```python
import httpx
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge},timeout=10)
datas = res.json()['data']
if datas['result'] == 'success':
return datas['validate']
def game_captcha(gt: str, challenge: str):
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge,'use_v3_model':True},timeout=10)
datas = res.json()['data']
if datas['result'] == 'success':
return datas['validate']
return None # 失败返回None 成功返回validate
```
#### --宣传--
欢迎大家支持我的其他项目喵~~~~~~~~

View File

@@ -132,7 +132,7 @@ if __name__ == "__main__":
# wb.write(icon_img_jpg)
# V3测试代码
pic = "./img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg"
pic = "img_saved/7fe559a85bac4c03bc6ea7b2e85325bf.jpg"
print("推理图片为:",pic)
with open(pic, "rb") as f:
img = f.read()

23
main.py
View File

@@ -15,7 +15,11 @@ app = FastAPI()
@app.get("/pass_nine")
def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Query(default=None)):
def get_pic(gt: str = Query(...),
challenge: str = Query(...),
point: str = Query(default=None),
use_v3_model = Query(default=True)
):
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
t = time.time()
@@ -32,13 +36,16 @@ def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Quer
pic_content,pic_name = crack.get_pic()
crop_image_v3(pic_content)
with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
icon_image = rb.read()
with open(f"{validate_path}/nine.jpg", "rb") as rb:
bg_image = rb.read()
result_list = predict_onnx(icon_image, bg_image, point)
if use_v3_model:
result_list = predict_onnx_pdl(validate_path)
else:
with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
icon_image = rb.read()
with open(f"{validate_path}/nine.jpg", "rb") as rb:
bg_image = rb.read()
result_list = predict_onnx(icon_image, bg_image, point)
point_list = [f"{col}_{row}" for row, col in result_list]
wait_time = 4.0 - (time.time() - t)
time.sleep(wait_time)
@@ -59,7 +66,7 @@ def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Quer
if __name__ == "__main__":
from predict import predict_onnx
from predict import predict_onnx,predict_onnx_pdl
import uvicorn
print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
print(f"{' '*10}api所需参数gt、challenge、point(可选){' '*10}")

View File

@@ -64,14 +64,15 @@ def predict(icon_image, bg_image):
print(largest_three)
print("识别完成,耗时:", time.time() - start)
# 加载onnx模型
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
def load_model(name='PP-HGNetV2-B4.onnx'):
    """Load an ONNX model from the ``model`` directory next to this file.

    The inference session and the name of its first input are published
    through the module-level globals ``session`` and ``input_name``; the
    elapsed loading time is printed.

    Args:
        name: File name of the model inside the ``model`` directory.
    """
    global session, input_name
    began = time.time()
    # Resolve the model relative to this source file, not the working directory.
    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model')
    session = ort.InferenceSession(os.path.join(model_dir, name))
    first_input = session.get_inputs()[0]
    input_name = first_input.name
    print("加载模型,耗时:", time.time() - began)
def predict_onnx(icon_image, bg_image, point = None):
@@ -145,15 +146,76 @@ def predict_onnx(icon_image, bg_image, point = None):
draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_pdl(images_path):
    """Classify the cropped images in a directory and pick the matching cells.

    Every file in ``images_path`` whose name contains ``"cropped"`` is
    preprocessed and sent through the module-level ONNX ``session`` as one
    batch, in sorted file-name order. The last image is the target; each of
    the first nine images predicted to be the same class as the target
    contributes its grid coordinate to the answer. The answer is also drawn
    onto ``nine.jpg`` from the same directory via ``draw_points_on_image``.

    Args:
        images_path: Directory holding the cropped images and ``nine.jpg``.

    Returns:
        A list of 1-based ``[a, b]`` pairs taken from a 3x3 table; the caller
        in main.py unpacks each pair as ``row, col``.

    Raises:
        FileNotFoundError: If no file containing ``"cropped"`` exists in
            ``images_path``.
    """
    # 1-based 3x3 grid: [1, 1], [1, 2], [1, 3], [2, 1], ... [3, 3].
    coordinates = [[row, col] for row in range(1, 4) for col in range(1, 4)]
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def data_transforms(path):
        """Turn one image file into a normalised float32 CHW array."""
        # Close the file handle promptly and force three channels so the
        # per-channel mean/std below always broadcast.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        # Scale so that the shorter side becomes 232 pixels.
        if rgb.width < rgb.height:
            new_size = (232, int(232 * rgb.height / rgb.width))
        else:
            new_size = (int(232 * rgb.width / rgb.height), 232)
        resized_img = rgb.resize(new_size, Image.BICUBIC)
        # NOTE(review): this keeps the top-left 224x224 region, not a centre
        # crop - confirm it matches the preprocessing used for training.
        cropped_img = resized_img.crop((0, 0, 224, 224))
        img_array = np.array(cropped_img).astype(np.float32)
        img_array /= 255.0
        img_array -= np.array(mean)
        img_array /= np.array(std)
        # HWC -> CHW
        return np.transpose(img_array, (2, 0, 1))

    image_names = sorted(pic for pic in os.listdir(images_path) if "cropped" in pic)
    if not image_names:
        # Report the directory that was searched; no per-image path exists here.
        raise FileNotFoundError(f"先使用切图代码切图至{images_path}再推理,图片命名如cropped_9.jpg,从0到9共十个,最后一个是检测目标")
    images = [data_transforms(os.path.join(images_path, pic)) for pic in image_names]

    start = time.time()
    # Feed one explicit (N, 3, 224, 224) float32 batch.
    outputs = session.run(None, {input_name: np.stack(images)})[0]
    result = [np.argmax(one) for one in outputs]
    target = result[-1]
    # zip stops at nine cells and never indexes past the available predictions.
    answer = [cell for cell, label in zip(coordinates, result[:-1]) if label == target]
    print(f"识别完成{answer},耗时: {time.time() - start}")
    with open(os.path.join(images_path, "nine.jpg"), 'rb') as f:
        bg_image = f.read()
    draw_points_on_image(bg_image, answer)
    return answer
# Script entry point: run a local prediction when executed directly;
# when imported, only load the default model.
if __name__ == "__main__":
icon_path = "img_2_val/cropped_9.jpg"
bg_path = "img_2_val/nine.jpg"
with open(icon_path, "rb") as rb:
if icon_path.endswith('.png'):
icon_image = convert_png_to_jpg(rb.read())
else:
icon_image = rb.read()
with open(bg_path, "rb") as rb:
bg_image = rb.read()
predict_onnx(icon_image, bg_image)
# Use resnet18.onnx
# load_model("resnet18.onnx")
# icon_path = "img_2_val/cropped_9.jpg"
# bg_path = "img_2_val/nine.jpg"
# with open(icon_path, "rb") as rb:
# if icon_path.endswith('.png'):
# icon_image = convert_png_to_jpg(rb.read())
# else:
# icon_image = rb.read()
# with open(bg_path, "rb") as rb:
# bg_image = rb.read()
# predict_onnx(icon_image, bg_image)
# Use PP-HGNetV2-B4.onnx
load_model()
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
else:
load_model()

46
train_paddle.py Normal file
View File

@@ -0,0 +1,46 @@
import os
from paddlex.utils.result_saver import try_except_decorator
from paddlex.utils.config import parse_args, get_config
from paddlex.utils.errors import raise_unsupported_api_error
from paddlex.model import _ModelBasedConfig
# Print the expected dataset directory layout as a reminder to the user.
# This is a module-level statement, so it runs whenever the file is executed
# or imported.
print(f"""数据集格式如下:
dataset
├─images #所有图片存放路径
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>如15 地球仪
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>如images/001.jpg 0
└─验证集和测试集同上
""")
class Engine(object):
    """Config-driven runner that dispatches one PaddleX model action.

    The configuration file and its overrides are fixed in ``__init__``; the
    mode found in the resulting config decides what ``run`` does.
    """

    def __init__(self):
        """Parse arguments, pin the config/overrides, and build the model."""
        cli_args = parse_args()
        # Config file and overrides are hard-coded here, replacing whatever
        # parse_args() produced for these two fields.
        cli_args.config = 'PP-HGNetV2-B4.yaml'
        cli_args.override = ['Global.mode=train', 'Global.dataset_dir=dataset']
        cfg = get_config(cli_args.config, overrides=cli_args.override, show=False)
        # NOTE(review): _output is not read in this class; presumably
        # try_except_decorator uses it - verify before removing.
        self._mode = cfg.Global.mode
        self._output = cfg.Global.output
        self._model = _ModelBasedConfig(cfg)

    @try_except_decorator
    def run(self):
        """Execute the action selected by the configured mode.

        Returns:
            The model's result for ``check_dataset``, ``evaluate`` and
            ``export``; ``None`` for ``train`` and ``predict``.
        """
        mode = self._mode
        model = self._model
        if mode == "check_dataset":
            return model.check_dataset()
        if mode == "train":
            model.train()
            return None
        if mode == "evaluate":
            return model.evaluate()
        if mode == "export":
            return model.export()
        if mode == "predict":
            # Print each prediction result as it is produced.
            for res in model.predict():
                res.print(json_format=False)
            return None
        # Any other mode is reported through the paddlex error helper.
        raise_unsupported_api_error(f"{self._mode}", self.__class__)
if __name__ == "__main__":
    # Build the engine from the pinned config and run the configured mode.
    engine = Engine()
    engine.run()