Compare commits

...

5 Commits

5 changed files with 370 additions and 90 deletions

View File

@@ -1,4 +1,4 @@
# 九宫格测试代码
# 九宫格+点选测试代码
## **本项目仅供学习交流使用,请勿用于商业用途,否则后果自负。**
@@ -10,6 +10,8 @@
模型及V4数据集https://github.com/taisuii/ClassificationCaptchaOcr
点选检测模型https://github.com/Peterande/D-FINE
apihttps://github.com/ravizhan/geetest-v3-click-crack
## 运行步骤
@@ -17,8 +19,8 @@ apihttps://github.com/ravizhan/geetest-v3-click-crack
### 1.安装依赖本地必选使用docker跳至[5-b](#docker)
可选a.如果要训练paddle的话还得安装paddlex及图像分类模块安装看项目https://github.com/PaddlePaddle/PaddleX
* 必选!)b.模型需要在项目目录下新建一个model文件夹然后把模型文件放进去具体命名可以是resnet18.onnx或者PP-HGNetV2-B4.onnx默认使用PP-HGNetV2-B4模型如果用resnet则use_v3_model设置为False因为模型的输入输出不一样可以自行修改
可选b.d-fine训练看项目https://github.com/Peterande/D-FINE
* 必选!)c.模型需要在项目目录下新建一个model文件夹然后把模型文件放进去具体命名可以是resnet18.onnx或者PP-HGNetV2-B4.onnx默认使用PP-HGNetV2-B4模型如果用resnet则use_v3_model设置为False因为模型的输入输出不一样可以自行修改d-fine的模型同样放置此路径用于通过点选
```
pip install -r requirements.txt
@@ -29,7 +31,7 @@ pip install -r requirements.txt
pip install -r requirements_without_train.txt
```
### 2.自行准备数据集V3和V4有区别可选
### 2.自行准备数据集V3和V4有区别可选,点选可以自己生成,要有旋转、重叠、换色
##### a. 训练resnet18可选
@@ -48,28 +50,39 @@ pip install -r requirements_without_train.txt
└─验证集和测试集同上
```
##### c. 如果要切V3的图用crop_image.py的crop_image_v3切V4则使用crop_image自行编写切图脚本
##### b. 训练d-fine可选
数据集格式如d-fine中标识如果不修改源码则num_classes需+1采用coco格式即可我用的320*320dataloader注释掉了RandomZoomOut、RandomHorizontalFlip、RandomIoUCrop这些我全写在数据集生成中了
##### c. 如果要切V3的九宫格图用crop_image.py的crop_image_v3切V4则使用crop_image自行编写切图脚本
### 3.训练模型(可选)
- 训练resnet18运行 `python train.py`
- 如果训练PP-HGNetV2-B4运行`python train_paddle.py`
- 训练d-fine参照原项目一个模型拿下比ddddocr+相似性检测资源开销小点
### 4.模型转换为onnx可选
### 4-a.PP-HGNetV2-B4模型和resnet模型转换为onnx可选
- 运行 `python convert.py`自行进去修改需要转换的模型一般是选loss小的
- paddle模型转换要装paddle2onnx详情参见https://www.paddlepaddle.org.cn/documentation/docs/guides/advanced/model_to_onnx_cn.html
### 4-b.d-fine转换为onnx可选
- 依原项目转换
- 推理时图像预处理应与训练时一致d-fine仓库中onnx推理的预处理和训练不一致……
### 5-a.启动fastapi服务必须要有训练完成的onnx格式模型
运行 `python main.py`默认用的paddle的onnx模型如果要用resnet18可以自己改注释
运行 `python main.py`默认用的paddle的onnx模型如果要用resnet18可以自己改注释或者`uvicorn main:app --host 0.0.0.0 --port 9645 --reload`
由于轨迹问题可能会出现验证正确但是结果失败所以建议增加retry次数训练后的paddle模型正确率在99.9%以上
### 5-b.使用docker启动服务
镜像地址为<span id="docker">luguoyixiazi/test_nine:25.6.20</span>
镜像地址为<span id="docker">luguoyixiazi/test_nine:25.7.2</span>
运行时只需指定绑定的port即可api端口为/pass_nine必填参数gt、challenge
运行时只需指定绑定的port和两个环境变量`use_pdl``use_dfine`1为启用模型0为不启用默认均启用api端口为/pass_uni必填参数gt、challenge单独的pass_nine和pass_icon也写了有更多可选参数
### 6.api调用
@@ -79,15 +92,19 @@ python调用如
import httpx
def game_captcha(gt: str, challenge: str):
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge,'use_v3_model':True,"save_result":False},timeout=10)
res = httpx.get("http://127.0.0.1:9645/pass_uni",params={'gt':gt,'challenge':challenge},timeout=10)
# 或者依旧使用pass_nine路径
# res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge,'use_v3_model':True,"save_result":False},timeout=10)
datas = res.json()['data']
if datas['result'] == 'success':
return datas['validate']
return None # 失败返回None 成功返回validate
```
在snap hutao中的服务端口为/pass_hutao,返回值已做对齐填写api如`(http://127.0.0.1:9645/pass_hutao?gt={0}&challenge={1})`即可
具体调用代码看使用项目此处示例仅为API url和参数示例
#### --宣传--
欢迎大家支持我的其他项目喵~~~~~~~~
欢迎大家支持我的其他项目(搭配使用)喵~~~~~~~~

View File

@@ -10,8 +10,7 @@ from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from crop_image import validate_path
from os import path as PATH
import os
class Crack:
def __init__(self, gt=None, challenge=None):
self.pic_path = None
@@ -440,9 +439,10 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
pic_data = self.session.get(pic_url).content
pic_name = data["pic"].split("/")[-1]
pic_type = data["pic_type"]
with open(PATH.join(validate_path,pic_name),'wb+') as f:
f.write(pic_data)
return pic_data,pic_name
return pic_data,pic_name,pic_type
def verify(self, points: list):
u = self.enc_key
@@ -483,4 +483,4 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
"w": w
}
resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
return resp[1:-1]
return resp[1:-1]

View File

@@ -1,6 +1,5 @@
from PIL import Image, ImageFont, ImageDraw, ImageOps
from io import BytesIO
import cv2
import numpy as np
import os
current_path = os.getcwd()
@@ -12,8 +11,8 @@ os.makedirs(validate_path,exist_ok=True)
os.makedirs(save_path,exist_ok=True)
os.makedirs(save_pass_path,exist_ok=True)
os.makedirs(save_fail_path,exist_ok=True)
def draw_points_on_image(bg_image, answer):
import cv2
# 将背景图片转换为OpenCV格式
bg_image_cv = cv2.imdecode(np.frombuffer(bg_image, dtype=np.uint8), cv2.IMREAD_COLOR)
@@ -53,6 +52,10 @@ def convert_png_to_jpg(png_bytes: bytes) -> bytes:
# 返回保存后的 JPG 图像的 bytes
return output_bytes.getvalue()
def bytes_to_pil(image_bytes):
    """Decode raw image bytes into an RGB PIL image."""
    buffer = BytesIO(image_bytes)
    decoded = Image.open(buffer)
    return decoded.convert('RGB')
def crop_image(image_bytes, coordinates):
img = Image.open(BytesIO(image_bytes))
@@ -84,8 +87,7 @@ def crop_image_v3(image_bytes):
[[232, 232], [344, 344]],#第三行
[[2, 344], [42, 384]] #要验证的
]
image = Image.open(BytesIO(image_bytes))
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
image = bytes_to_pil(image_bytes)
imageNew = Image.new('RGB', (300,261),(0,0,0))
images = []
for i, (start_point, end_point) in enumerate(coordinates):

269
main.py
View File

@@ -1,43 +1,106 @@
import os
import json
import time
import random
from crack import Crack
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
import httpx
from fastapi import FastAPI,Query
from fastapi.responses import JSONResponse
import shutil
import os
import random
import logging
import logging.config
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI,Query, HTTPException
from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
port = 9645
# api
app = FastAPI()
PORT = 9645
platform = os.name
# --- 日志配置字典 ---
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"()": f"uvicorn.{'_logging' if platform== 'nt' else 'logging'}.DefaultFormatter",
"fmt": "%(levelprefix)s %(asctime)s | %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
},
"loggers": {
# 将根日志记录器的级别设置为 INFO
"": {"handlers": ["default"], "level": "INFO"},
"uvicorn.error": {"level": "INFO"},
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)
def get_available_hosts() -> set[str]:
    """Return the machine's reachable IPv4 addresses (always includes loopback)."""
    import socket

    found = {"127.0.0.1"}
    try:
        infos = socket.getaddrinfo(socket.gethostname(), None, socket.AF_INET)
    except socket.gaierror:
        # Hostname resolution failed: probe the outbound address instead.
        # A UDP connect() sends no packets; it only makes the kernel pick a
        # local source address we can read back.
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                probe.connect(("8.8.8.8", 80))
                found.add(probe.getsockname()[0])
        except OSError:
            pass
    else:
        found.update(info[4][0] for info in infos)
    return found
@app.get("/pass_nine")
def get_pic(gt: str = Query(...),
challenge: str = Query(...),
point: str = Query(default=None),
use_v3_model = Query(default=True),
save_result = Query(default=False)
):
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
t = time.time()
crack = Crack(gt, challenge)
crack.gettype()
crack.get_c_s()
time.sleep(random.uniform(0.4,0.6))
crack.ajax()
pic_content,pic_name = crack.get_pic()
crop_image_v3(pic_content)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log reachable service addresses and routes on startup, a banner on shutdown."""
    banner = "=" * 50
    logger.info(banner)
    logger.info("启动服务中...")
    # Prefer host/port from the uvicorn server config when one is attached.
    if app.servers:
        cfg = app.servers[0].config
        host, port = cfg.host, cfg.port
    else:
        host, port = "0.0.0.0", PORT
    if host == "0.0.0.0":
        # Bound on all interfaces: enumerate every concrete address.
        logger.info(f"服务地址(依需求选用docker中使用宿主机host:{port}):")
        for addr in sorted(get_available_hosts()):
            logger.info(f" - http://{addr}:{port}")
    else:
        logger.info(f"服务地址: http://{host}:{port}")
    logger.info(f"可用服务路径如下:")
    for r in app.routes:
        logger.info(f" -{r.methods} {r.path}")
    logger.info(banner)
    yield
    logger.info(banner)
    logger.info("服务关闭")
    logger.info(banner)
app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)
def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
    """Run the geetest handshake and download the captcha picture.

    Returns:
        (crack, pic_content, pic_name, pic_type) — pic_type distinguishes
        "nine" (grid) from "icon" (click) downstream.
    """
    # Use the module-level logger (configured via LOGGING_CONFIG) instead of
    # the root logging module so messages share one format/handler.
    logger.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
    crack = Crack(gt, challenge)
    crack.gettype()
    crack.get_c_s()
    # Short random pause mimics human latency between handshake calls.
    time.sleep(random.uniform(0.4, 0.6))
    crack.ajax()
    pic_content, pic_name, pic_type = crack.get_pic()
    return crack, pic_content, pic_name, pic_type
def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -> list[str]:
"""处理九宫格验证码,返回坐标点列表。"""
crop_image_v3(pic_content)
if use_v3_model:
result_list = predict_onnx_pdl(validate_path)
else:
@@ -46,30 +109,130 @@ def get_pic(gt: str = Query(...),
with open(f"{validate_path}/nine.jpg", "rb") as rb:
bg_image = rb.read()
result_list = predict_onnx(icon_image, bg_image, point)
return [f"{col}_{row}" for row, col in result_list]
def do_pass_icon(pic: Any, draw_result: bool) -> list[str]:
    """Solve an icon-click captcha.

    Runs d-fine detection, then converts each pixel center into the
    "x_y" string format expected by the verify step, scaled to 0-10000.
    """
    result_list = predict_onnx_dfine(pic, draw_result)
    # Route through the configured logger rather than a bare print so the
    # output matches the rest of the service's log format.
    logger.info(result_list)
    # 333 is presumably the rendered captcha size in px — TODO confirm
    # against the captcha page markup.
    return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def save_image_for_train(pic_name, pic_type, passed):
    """Archive the downloaded picture and its crops for later training.

    The raw picture moves into save_path; the per-request crop files move
    into a pass/fail subfolder named after the picture (extension stripped).
    """
    shutil.move(os.path.join(validate_path, pic_name),
                os.path.join(save_path, pic_name))
    base_dir = save_pass_path if passed else save_fail_path
    dest = os.path.join(base_dir, pic_name.split('.')[0])
    os.makedirs(dest, exist_ok=True)
    # Nine-grid crops are prefixed "cropped"; icon crops are prefixed "icon".
    for fname in os.listdir(validate_path):
        if pic_type == "nine" and fname.startswith('cropped'):
            shutil.move(os.path.join(validate_path, fname),
                        os.path.join(dest, fname))
        elif pic_type == "icon" and fname.startswith('icon'):
            shutil.move(os.path.join(validate_path, fname),
                        os.path.join(dest, fname))
def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) -> JSONResponse:
"""
统一处理所有验证码请求的核心函数。
"""
start_time = time.monotonic()
try:
# 1. 准备
crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
point_list = [f"{col}_{row}" for row, col in result_list]
wait_time = max(0,4.0 - (time.time() - t))
time.sleep(wait_time)
result = json.loads(crack.verify(point_list))
if save_result:
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if 'validate' in result['data']:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
# 2. 识别
if pic_type == "nine":
point_list = do_pass_nine(
pic_content,
use_v3_model=kwargs.get("use_v3_model", True),
point=kwargs.get("point",None)
)
elif pic_type == "icon":
point_list = do_pass_icon(pic_content, save_result)
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
total_time = time.time() - t
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
return JSONResponse(content=result)
raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
# 3. 验证
elapsed = time.monotonic() - start_time
wait_time = max(0, 4.0 - elapsed)
time.sleep(wait_time)
response_str = crack.verify(point_list)
result = json.loads(response_str)
# 4. 后处理
passed = 'validate' in result.get('data', {})
if save_result:
save_image_for_train(pic_name, pic_type, passed)
else:
os.remove(os.path.join(validate_path,pic_name))
total_time = time.monotonic() - start_time
logging.info(
f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
f"结果: {result}"
)
return JSONResponse(content=result)
except Exception as e:
logging.error(f"服务错误: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
)
@app.get("/pass_nine")
def pass_nine(gt: str = Query(...),
              challenge: str = Query(...),
              point: str = Query(default=None),
              use_v3_model = Query(default=True),
              save_result = Query(default=False)
              ):
    """Nine-grid endpoint: forwards point / use_v3_model to the shared handler."""
    return handle_pass_request(gt, challenge, save_result,
                               point=point, use_v3_model=use_v3_model)
@app.get("/pass_icon")
def pass_icon(gt: str = Query(...),
              challenge: str = Query(...),
              save_result = Query(default=False)
              ):
    """Icon-click endpoint: delegates straight to the shared handler."""
    return handle_pass_request(gt=gt, challenge=challenge, save_result=save_result)
@app.get("/pass_uni")
def pass_uni(gt: str = Query(...),
             challenge: str = Query(...),
             save_result = Query(default=False)
             ):
    """Universal endpoint: the handler picks nine/icon from the picture type."""
    return handle_pass_request(gt=gt, challenge=challenge, save_result=save_result)
@app.get("/pass_hutao")
def pass_hutao(gt: str = Query(...),
               challenge: str = Query(...),
               save_result = Query(default=False)):
    """Snap Hutao compatible endpoint: re-shapes the uniform response body."""
    try:
        inner = handle_pass_request(gt, challenge, save_result)
        status_code = inner.status_code
        payload = json.loads(inner.body.decode("utf-8"))
        data = payload.get("data", {})
        succeeded = (status_code == 200
                     and payload.get("status", False) == "success"
                     and "validate" in data)
        if succeeded:
            body = {"code": 0, "data": {"gt": gt, "challenge": challenge,
                                        "validate": data["validate"]}}
        else:
            # On failure, pass the whole upstream payload through for debugging.
            body = {"code": 1, "data": {"gt": gt, "challenge": challenge,
                                        "validate": payload}}
        return JSONResponse(content=body, status_code=status_code)
    except Exception as e:
        logging.error(f"修改路由错误: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "An internal server error occurred.", "detail": str(e)}
        )
if __name__ == "__main__":
from predict import predict_onnx,predict_onnx_pdl
import uvicorn
print(f"{' '*10}api: http://0.0.0.0:{port}/pass_nine{' '*10}")
print(f"{' '*10}api所需参数gt、challenge、point(可选){' '*10}")
uvicorn.run(app,host="0.0.0.0",port=port)
uvicorn.run(app,port=PORT)

View File

@@ -1,20 +1,15 @@
import os
import numpy as np
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
import time
import cv2
from PIL import Image
from PIL import Image, ImageDraw
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
from train import MyResNet18, data_transform
import torch
from train import MyResNet18, data_transform
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
coordinates = [
@@ -74,10 +69,20 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
model_path = os.path.join(current_dir, 'model', name)
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
print(f"加载{name}模型,耗时:{time.time() - start}")
def load_dfine_model(name='d-fine-n.onnx'):
    """Load the d-fine detection ONNX model into the module-wide session."""
    global session_dfine
    t0 = time.time()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    session_dfine = ort.InferenceSession(os.path.join(base_dir, 'model', name))
    print(f"加载{name}模型,耗时:{time.time() - t0}")
def predict_onnx(icon_image, bg_image, point = None):
import cv2
coordinates = [
[1, 1],
[1, 2],
@@ -145,7 +150,7 @@ def predict_onnx(icon_image, bg_image, point = None):
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"识别完成{answer},耗时: {time.time() - start}")
draw_points_on_image(bg_image, answer)
#draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_pdl(images_path):
@@ -194,15 +199,110 @@ def predict_onnx_pdl(images_path):
result = [np.argmax(one) for one in outputs]
target = result[-1]
answer = [coordinates[index] for index in range(9) if result[index] == target]
if len(answer) == 0:
all_sort =[np.argsort(one) for one in outputs]
answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
print(f"识别完成{answer},耗时: {time.time() - start}")
if os.path.exists(os.path.join(images_path,"nine.jpg")):
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
draw_points_on_image(bg_image, answer)
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_dfine(image, draw_result=False):
    """Detect click-captcha icons with the d-fine ONNX model.

    Args:
        image: raw picture bytes, or anything Image.open accepts (path/file).
        draw_result: when True, draw boxes and click centers and save a
            visualization into validate_path.

    Returns:
        List of (x, y) center points — one click target per detected class,
        classes ordered left-to-right.
    """
    input_nodes = session_dfine.get_inputs()
    output_nodes = session_dfine.get_outputs()
    image_input_name = input_nodes[0].name
    size_input_name = input_nodes[1].name
    output_names = [node.name for node in output_nodes]
    if isinstance(image, bytes):
        im_pil = bytes_to_pil(image)
    else:
        # BUG FIX: previously referenced the undefined name `image_path`,
        # raising NameError for any non-bytes input.
        im_pil = Image.open(image).convert("RGB")
    w, h = im_pil.size
    orig_size_np = np.array([[w, h]], dtype=np.int64)
    # Preprocessing mirrors training: plain bilinear resize to 320x320, /255.
    im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
    im_data = np.array(im_resized, dtype=np.float32) / 255.0
    im_data = im_data.transpose(2, 0, 1)       # HWC -> CHW
    im_data = np.expand_dims(im_data, axis=0)  # add batch dimension
    inputs = {
        image_input_name: im_data,
        size_input_name: orig_size_np
    }
    outputs = session_dfine.run(output_names, inputs)
    output_map = {name: data for name, data in zip(output_names, outputs)}
    labels = output_map['labels'][0]
    boxes = output_map['boxes'][0]
    scores = output_map['scores'][0]
    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
    # Keep detections above a fixed confidence threshold.
    mask = scores > 0.4
    filtered_labels = labels[mask]
    filtered_boxes = boxes[mask]
    filtered_scores = scores[mask]
    # Assign one drawing color per class id.
    rebuild_color = {}
    unique_labels = list(set(filtered_labels))
    for i, l_val in enumerate(unique_labels):
        class_id = int(l_val)
        if class_id not in rebuild_color:
            rebuild_color[class_id] = colors[i % len(colors)]
    # Group detections by class id.
    result = {k: [] for k in unique_labels}
    for i, box in enumerate(filtered_boxes):
        label_val = filtered_labels[i]
        class_id = int(label_val)
        score = filtered_scores[i]
        result[class_id].append({
            'box': box,
            'label_val': label_val,
            'score': score
        })
    # Within each class, sort boxes by bottom edge (y2) descending.
    for class_id in result:
        result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
    # Order classes left-to-right by the x1 of each class's first box.
    sorted_result = {}
    sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
    for class_id in sorted_class_ids:
        sorted_result[class_id] = result[class_id]
    points = []
    if draw_result:
        draw = ImageDraw.Draw(im_pil)
    for c1, class_id in enumerate(sorted_result):
        items = sorted_result[class_id]
        # Click target: center of the last box after the y2-descending sort.
        last_item = items[-1]
        center_x = (last_item['box'][0] + last_item['box'][2]) / 2
        center_y = (last_item['box'][1] + last_item['box'][3]) / 2
        text_position_center = (center_x, center_y)
        points.append(text_position_center)
        if draw_result:
            color = rebuild_color[class_id]
            draw.point((center_x, center_y), fill=color)
            text_center = f"{c1}"
            draw.text(text_position_center, text_center, fill=color)
            for c2, item in enumerate(items):
                box = item['box']
                score = item['score']
                draw.rectangle(list(box), outline=color, width=1)
                text = f"{class_id}_{c1}-{c2}: {score:.2f}"
                text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
                draw.text(text_position, text, fill=color)
    if draw_result:
        # Renamed from `save_path` to avoid shadowing the module-level import.
        out_path = os.path.join(validate_path, "icon_result.jpg")
        im_pil.save(out_path)
        print(f"图片可视化结果保存在{out_path}")
    print(f"图片顺序的中心点{points}")
    return points
# Report which execution device onnxruntime selected (e.g. CPU/GPU).
print(f"使用推理设备: {ort.get_device()}")
# Import-time model loading, gated by env vars: set use_pdl / use_dfine to 0
# to skip loading that model (both default to enabled).
if int(os.environ.get("use_pdl",1)):
    load_model()
if int(os.environ.get("use_dfine",1)):
    load_dfine_model()
if __name__ == "__main__":
# 使用resnet18.onnx
# load_model("resnet18.onnx")
@@ -218,7 +318,5 @@ if __name__ == "__main__":
# predict_onnx(icon_image, bg_image)
# 使用PP-HGNetV2-B4.onnx
load_model()
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
else:
load_model()
#predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)