mirror of
https://github.com/luguoyixiazi/test_nine.git
synced 2025-12-05 14:42:49 +08:00
增加paddle模型
This commit is contained in:
42
PP-HGNetV2-B4.yaml
Normal file
42
PP-HGNetV2-B4.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
Global:
|
||||
model: PP-HGNetV2-B4
|
||||
mode: check_dataset # check_dataset/train/evaluate/predict
|
||||
dataset_dir: "dataset"
|
||||
device: gpu:0
|
||||
output: "output"
|
||||
|
||||
CheckDataset:
|
||||
convert:
|
||||
enable: False
|
||||
src_dataset_type: null
|
||||
split:
|
||||
enable: False
|
||||
train_percent: null
|
||||
val_percent: null
|
||||
|
||||
Train:
|
||||
num_classes: 91
|
||||
epochs_iters: 100
|
||||
batch_size: 64
|
||||
learning_rate: 0.05
|
||||
pretrain_weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
|
||||
warmup_steps: 5
|
||||
resume_path: null
|
||||
log_interval: 1
|
||||
eval_interval: 1
|
||||
save_interval: 5
|
||||
|
||||
Evaluate:
|
||||
weight_path: "output/best_model/best_model.pdparams"
|
||||
log_interval: 1
|
||||
|
||||
Export:
|
||||
weight_path: https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNetV2_B4_ssld_pretrained.pdparams
|
||||
# weight_path: "./output/best_model/inference/inference"
|
||||
|
||||
Predict:
|
||||
batch_size: 1
|
||||
model_dir: "output/best_model/inference"
|
||||
input: "cropped_3.jpg"
|
||||
kernel_option:
|
||||
run_mode: paddle
|
||||
50
README.md
50
README.md
@@ -16,28 +16,48 @@ api:https://github.com/ravizhan/geetest-v3-click-crack
|
||||
|
||||
### 1.安装依赖
|
||||
|
||||
如果要训练paddle的话还得安装paddlex及图像分类模块,安装看项目https://github.com/PaddlePaddle/PaddleX
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2.自行准备数据集,V3和V4有区别
|
||||
|
||||
- 数据集详情参考上面标注的项目,但是上面项目是V4数据集,V3没有demo,自行发挥吧,用V4正确率有点感人,或许可以试试别的模型看看能不能泛化
|
||||
##### a. 训练resnet18
|
||||
|
||||
- 如果要切V3的图用crop_image.py的crop_image_v3,切V4则使用crop_image,自行编写切图脚本
|
||||
- 数据集详情参考上面标注的项目,但是上面项目是V4数据集,V3没有demo,自行发挥吧,用V4练V3不改代码正确率有点感人
|
||||
- 主要是V4的尺寸和V3有差别,V4的api直接给两张图,一张是目标图,一张是九宫格,V3放在一起要切目标,且V3目标图清晰度很低,V4九宫格切了之后是100 * 86的图(去掉黑边),但是V3九宫格切的是112 * 112,不确定V4九宫格内容在V3基础上做了什么变换,反正改预处理就完事了
|
||||
|
||||
##### b. 训练PP-HGNetV2-B4
|
||||
|
||||
在paddle上随便找的,数据集格式如下,如果拿V4练V3,建议是多整点变换
|
||||
|
||||
```
|
||||
dataset
|
||||
├─images #所有图片存放路径
|
||||
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>,如15 地球仪
|
||||
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>,如images/001.jpg 0
|
||||
└─验证集和测试集同上
|
||||
```
|
||||
|
||||
##### c. 如果要切V3的图用crop_image.py的crop_image_v3,切V4则使用crop_image,自行编写切图脚本
|
||||
|
||||
### 3.训练模型
|
||||
|
||||
训练运行 `python train.py`
|
||||
- 训练resnet18运行 `python train.py`
|
||||
- 如果训练PP-HGNetV2-B4运行`python train_paddle.py`
|
||||
|
||||
### 4.模型转换为onnx
|
||||
|
||||
运行 `python convert.py`(自行进去修改需要转换的模型,一般是选loss小的)
|
||||
- 运行 `python convert.py`(自行进去修改需要转换的模型,一般是选loss小的)
|
||||
- paddle模型转换要装paddle2onnx,详情参见https://www.paddlepaddle.org.cn/documentation/docs/guides/advanced/model_to_onnx_cn.html
|
||||
|
||||
### 5.启动fastapi服务
|
||||
|
||||
运行 `python main.py`
|
||||
运行 `python main.py`(默认用的paddle的onnx模型,如果要用resnet18可以自己改注释)
|
||||
|
||||
由于轨迹问题,可能会出现验证正确但是结果失败,所以建议增加retry次数
|
||||
|
||||
### 6.api调用
|
||||
|
||||
@@ -46,18 +66,14 @@ python调用如:
|
||||
```python
|
||||
import httpx
|
||||
|
||||
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge},timeout=10)
|
||||
|
||||
datas = res.json()['data']
|
||||
|
||||
if datas['result'] == 'success':
|
||||
|
||||
return datas['validate']
|
||||
def game_captcha(gt: str, challenge: str):
|
||||
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge,'use_v3_model':True},timeout=10)
|
||||
datas = res.json()['data']
|
||||
if datas['result'] == 'success':
|
||||
return datas['validate']
|
||||
return None # 失败返回None 成功返回validate
|
||||
```
|
||||
|
||||
#### --宣传--
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
欢迎大家支持我的其他项目喵~~~~~~~~
|
||||
|
||||
@@ -132,7 +132,7 @@ if __name__ == "__main__":
|
||||
# wb.write(icon_img_jpg)
|
||||
|
||||
# V3测试代码
|
||||
pic = "./img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg"
|
||||
pic = "img_saved/7fe559a85bac4c03bc6ea7b2e85325bf.jpg"
|
||||
print("推理图片为:",pic)
|
||||
with open(pic, "rb") as f:
|
||||
img = f.read()
|
||||
|
||||
21
main.py
21
main.py
@@ -15,7 +15,11 @@ app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/pass_nine")
|
||||
def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Query(default=None)):
|
||||
def get_pic(gt: str = Query(...),
|
||||
challenge: str = Query(...),
|
||||
point: str = Query(default=None),
|
||||
use_v3_model = Query(default=True)
|
||||
):
|
||||
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
|
||||
t = time.time()
|
||||
|
||||
@@ -33,12 +37,15 @@ def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Quer
|
||||
|
||||
crop_image_v3(pic_content)
|
||||
|
||||
with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
|
||||
icon_image = rb.read()
|
||||
with open(f"{validate_path}/nine.jpg", "rb") as rb:
|
||||
bg_image = rb.read()
|
||||
if use_v3_model:
|
||||
result_list = predict_onnx_pdl(validate_path)
|
||||
else:
|
||||
with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
|
||||
icon_image = rb.read()
|
||||
with open(f"{validate_path}/nine.jpg", "rb") as rb:
|
||||
bg_image = rb.read()
|
||||
result_list = predict_onnx(icon_image, bg_image, point)
|
||||
|
||||
result_list = predict_onnx(icon_image, bg_image, point)
|
||||
point_list = [f"{col}_{row}" for row, col in result_list]
|
||||
wait_time = 4.0 - (time.time() - t)
|
||||
time.sleep(wait_time)
|
||||
@@ -59,7 +66,7 @@ def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Quer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from predict import predict_onnx
|
||||
from predict import predict_onnx,predict_onnx_pdl
|
||||
import uvicorn
|
||||
print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
|
||||
print(f"{' '*10}api所需参数:gt、challenge、point(可选){' '*10}")
|
||||
|
||||
98
predict.py
98
predict.py
@@ -64,14 +64,15 @@ def predict(icon_image, bg_image):
|
||||
print(largest_three)
|
||||
print("识别完成,耗时:", time.time() - start)
|
||||
|
||||
|
||||
# 加载onnx模型
|
||||
start = time.time()
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
print("加载模型,耗时:", time.time() - start)
|
||||
def load_model(name='PP-HGNetV2-B4.onnx'):
|
||||
# 加载onnx模型
|
||||
global session,input_name
|
||||
start = time.time()
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
model_path = os.path.join(current_dir, 'model', name)
|
||||
session = ort.InferenceSession(model_path)
|
||||
input_name = session.get_inputs()[0].name
|
||||
print("加载模型,耗时:", time.time() - start)
|
||||
|
||||
|
||||
def predict_onnx(icon_image, bg_image, point = None):
|
||||
@@ -145,15 +146,76 @@ def predict_onnx(icon_image, bg_image, point = None):
|
||||
draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
def predict_onnx_pdl(images_path):
|
||||
coordinates = [
|
||||
[1, 1],
|
||||
[1, 2],
|
||||
[1, 3],
|
||||
[2, 1],
|
||||
[2, 2],
|
||||
[2, 3],
|
||||
[3, 1],
|
||||
[3, 2],
|
||||
[3, 3],
|
||||
]
|
||||
def data_transforms(path):
|
||||
# 打开图片
|
||||
img = Image.open(path)
|
||||
# 调整图片大小为232x224(假设最短边长度调整为232像素)
|
||||
if img.width < img.height:
|
||||
new_size = (232, int(232 * img.height / img.width))
|
||||
else:
|
||||
new_size = (int(232 * img.width / img.height), 232)
|
||||
resized_img = img.resize(new_size, Image.BICUBIC)
|
||||
# 裁剪图片为224x224
|
||||
cropped_img = resized_img.crop((0, 0, 224, 224))
|
||||
# 将图像转换为NumPy数组并进行归一化处理
|
||||
img_array = np.array(cropped_img).astype(np.float32)
|
||||
img_array /= 255.0
|
||||
mean = [0.485, 0.456, 0.406]
|
||||
std = [0.229, 0.224, 0.225]
|
||||
img_array -= np.array(mean)
|
||||
img_array /= np.array(std)
|
||||
# 将通道维度移到前面
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
return img_array
|
||||
images = []
|
||||
for pic in sorted(os.listdir(images_path)):
|
||||
if "cropped" not in pic:
|
||||
continue
|
||||
image_path = os.path.join(images_path,pic)
|
||||
images.append(data_transforms(image_path))
|
||||
if len(images) == 0:
|
||||
raise FileNotFoundError(f"先使用切图代码切图至{image_path}再推理,图片命名如cropped_9.jpg,从0到9共十个,最后一个是检测目标")
|
||||
start = time.time()
|
||||
outputs = session.run(None, {input_name: images})[0]
|
||||
result = [np.argmax(one) for one in outputs]
|
||||
target = result[-1]
|
||||
answer = [coordinates[index] for index in range(9) if result[index] == target]
|
||||
print(f"识别完成{answer},耗时: {time.time() - start}")
|
||||
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
|
||||
bg_image = f.read()
|
||||
draw_points_on_image(bg_image, answer)
|
||||
return answer
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
icon_path = "img_2_val/cropped_9.jpg"
|
||||
bg_path = "img_2_val/nine.jpg"
|
||||
with open(icon_path, "rb") as rb:
|
||||
if icon_path.endswith('.png'):
|
||||
icon_image = convert_png_to_jpg(rb.read())
|
||||
else:
|
||||
icon_image = rb.read()
|
||||
with open(bg_path, "rb") as rb:
|
||||
bg_image = rb.read()
|
||||
predict_onnx(icon_image, bg_image)
|
||||
# 使用resnet18.onnx
|
||||
# load_model("resnet18.onnx")
|
||||
# icon_path = "img_2_val/cropped_9.jpg"
|
||||
# bg_path = "img_2_val/nine.jpg"
|
||||
# with open(icon_path, "rb") as rb:
|
||||
# if icon_path.endswith('.png'):
|
||||
# icon_image = convert_png_to_jpg(rb.read())
|
||||
# else:
|
||||
# icon_image = rb.read()
|
||||
# with open(bg_path, "rb") as rb:
|
||||
# bg_image = rb.read()
|
||||
# predict_onnx(icon_image, bg_image)
|
||||
|
||||
# 使用PP-HGNetV2-B4.onnx
|
||||
load_model()
|
||||
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
|
||||
else:
|
||||
load_model()
|
||||
46
train_paddle.py
Normal file
46
train_paddle.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
|
||||
from paddlex.utils.result_saver import try_except_decorator
|
||||
from paddlex.utils.config import parse_args, get_config
|
||||
from paddlex.utils.errors import raise_unsupported_api_error
|
||||
from paddlex.model import _ModelBasedConfig
|
||||
|
||||
print(f"""数据集格式如下:
|
||||
dataset
|
||||
├─images #所有图片存放路径
|
||||
├─label.txt #标签路径,每一行数据格式为 <序号>+<空格>+<类别>,如15 地球仪
|
||||
├─train.txt #训练图片,每一行数据格式为 <图片路径>+<空格>+<类别>,如images/001.jpg 0
|
||||
└─验证集和测试集同上
|
||||
""")
|
||||
class Engine(object):
|
||||
"""Engine"""
|
||||
|
||||
def __init__(self):
|
||||
args = parse_args()
|
||||
args.config='PP-HGNetV2-B4.yaml'
|
||||
args.override=['Global.mode=train', 'Global.dataset_dir=dataset']
|
||||
config = get_config(args.config, overrides=args.override, show=False)
|
||||
self._mode = config.Global.mode
|
||||
self._output = config.Global.output
|
||||
self._model = _ModelBasedConfig(config)
|
||||
|
||||
@try_except_decorator
|
||||
def run(self):
|
||||
"""the main function"""
|
||||
if self._mode == "check_dataset":
|
||||
return self._model.check_dataset()
|
||||
elif self._mode == "train":
|
||||
self._model.train()
|
||||
elif self._mode == "evaluate":
|
||||
return self._model.evaluate()
|
||||
elif self._mode == "export":
|
||||
return self._model.export()
|
||||
elif self._mode == "predict":
|
||||
for res in self._model.predict():
|
||||
res.print(json_format=False)
|
||||
else:
|
||||
raise_unsupported_api_error(f"{self._mode}", self.__class__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
Engine().run()
|
||||
Reference in New Issue
Block a user