luguoyixiazi
2025-07-02 13:14:05 +08:00
committed by GitHub
parent 657096ee2d
commit 1beae61927
4 changed files with 341 additions and 78 deletions


@@ -10,8 +10,8 @@ from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from crop_image import validate_path
from os import path as PATH
import os
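# Route outgoing HTTPS requests through a local proxy; adjust or remove if nothing is listening on 127.0.0.1:10809.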
os.environ['https_proxy']="http://127.0.0.1:10809"
class Crack:
def __init__(self, gt=None, challenge=None):
self.pic_path = None
@@ -440,9 +440,10 @@ Bm1Zzu+l8nSOqAurgQIDAQAB
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
pic_data = self.session.get(pic_url).content
pic_name = data["pic"].split("/")[-1]
pic_type = data["pic_type"]
with open(PATH.join(validate_path,pic_name),'wb+') as f:
f.write(pic_data)
return pic_data,pic_name
return pic_data,pic_name,pic_type
def verify(self, points: list):
u = self.enc_key


@@ -1,6 +1,5 @@
from PIL import Image, ImageFont, ImageDraw, ImageOps
from io import BytesIO
import cv2
import numpy as np
import os
current_path = os.getcwd()
@@ -12,8 +11,8 @@ os.makedirs(validate_path,exist_ok=True)
os.makedirs(save_path,exist_ok=True)
os.makedirs(save_pass_path,exist_ok=True)
os.makedirs(save_fail_path,exist_ok=True)
def draw_points_on_image(bg_image, answer):
import cv2
# Convert the background image to OpenCV format
bg_image_cv = cv2.imdecode(np.frombuffer(bg_image, dtype=np.uint8), cv2.IMREAD_COLOR)
@@ -53,6 +52,10 @@ def convert_png_to_jpg(png_bytes: bytes) -> bytes:
# Return the bytes of the saved JPG image
return output_bytes.getvalue()
def bytes_to_pil(image_bytes):
image = Image.open(BytesIO(image_bytes))
image = image.convert('RGB')
return image
def crop_image(image_bytes, coordinates):
img = Image.open(BytesIO(image_bytes))
@@ -84,8 +87,7 @@ def crop_image_v3(image_bytes):
[[232, 232], [344, 344]],# third row
[[2, 344], [42, 384]] # the icon to verify
]
image = Image.open(BytesIO(image_bytes))
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
image = bytes_to_pil(image_bytes)
imageNew = Image.new('RGB', (300,261),(0,0,0))
images = []
for i, (start_point, end_point) in enumerate(coordinates):

main.py

@@ -1,43 +1,105 @@
import os
import json
import time
import random
from crack import Crack
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
import httpx
from fastapi import FastAPI,Query
from fastapi.responses import JSONResponse
import shutil
import os
import random
import logging
import logging.config
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI,Query, HTTPException
from predict import predict_onnx,predict_onnx_pdl,predict_onnx_dfine
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
port = 9645
# api
app = FastAPI()
PORT = 9645
# --- Logging configuration dict ---
LOGGING_CONFIG = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"()": "uvicorn._logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(asctime)s | %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",
},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
},
"loggers": {
# Set the root logger level to INFO
"": {"handlers": ["default"], "level": "INFO"},
"uvicorn.error": {"level": "INFO"},
"uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
},
}
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)
def get_available_hosts() -> set[str]:
"""获取本机所有可用的IPv4地址。"""
import socket
hosts = {"127.0.0.1"}
try:
hostname = socket.gethostname()
addr_info = socket.getaddrinfo(hostname, None, socket.AF_INET)
hosts.update({info[4][0] for info in addr_info})
except socket.gaierror:
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
hosts.add(s.getsockname()[0])
except OSError:
pass
return hosts
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("="*50)
logger.info("启动服务中...")
# 从 uvicorn 配置中获取 host 和 port
server = app.servers[0] if app.servers else None
host = server.config.host if server else "0.0.0.0"
port = server.config.port if server else PORT
if host == "0.0.0.0":
available_hosts = get_available_hosts()
logger.info(f"服务地址(依需求选用docker中使用宿主机host:{port}):")
for h in sorted(list(available_hosts)):
logger.info(f" - http://{h}:{port}")
else:
logger.info(f"服务地址: http://{host}:{port}")
logger.info(f"可用服务路径如下:")
for route in app.routes:
logger.info(f" -{route.methods} {route.path}")
logger.info("="*50)
@app.get("/pass_nine")
def get_pic(gt: str = Query(...),
challenge: str = Query(...),
point: str = Query(default=None),
use_v3_model = Query(default=True),
save_result = Query(default=False)
):
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
t = time.time()
yield
logger.info("="*50)
logger.info("服务关闭")
logger.info("="*50)
app = FastAPI(title="Geetest V3 icon-click + nine-grid", lifespan=lifespan)
def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
"""获取信息。"""
logging.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
crack = Crack(gt, challenge)
crack.gettype()
crack.get_c_s()
time.sleep(random.uniform(0.4,0.6))
crack.ajax()
pic_content,pic_name,pic_type = crack.get_pic()
return crack,pic_content,pic_name,pic_type
pic_content,pic_name = crack.get_pic()
def do_pass_nine(pic_content: bytes, use_v3_model: bool, point: Optional[str]) -> list[str]:
"""处理九宫格验证码,返回坐标点列表。"""
crop_image_v3(pic_content)
if use_v3_model:
result_list = predict_onnx_pdl(validate_path)
else:
@@ -46,30 +108,130 @@ def get_pic(gt: str = Query(...),
with open(f"{validate_path}/nine.jpg", "rb") as rb:
bg_image = rb.read()
result_list = predict_onnx(icon_image, bg_image, point)
return [f"{col}_{row}" for row, col in result_list]
point_list = [f"{col}_{row}" for row, col in result_list]
wait_time = max(0,4.0 - (time.time() - t))
time.sleep(wait_time)
result = json.loads(crack.verify(point_list))
if save_result:
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if 'validate' in result['data']:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
def do_pass_icon(pic: Any, draw_result: bool) -> list[str]:
"""Solve the icon-click captcha and return the point list."""
result_list = predict_onnx_dfine(pic,draw_result)
print(result_list)
return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]
def save_image_for_train(pic_name,pic_type,passed):
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if passed:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic_type == "nine" and pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
if pic_type == "icon" and pic.startswith('icon'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) -> JSONResponse:
"""
统一处理所有验证码请求的核心函数。
"""
start_time = time.monotonic()
try:
# 1. Prepare
crack, pic_content, pic_name, pic_type = prepare(gt, challenge)
# 2. Recognize
if pic_type == "nine":
point_list = do_pass_nine(
pic_content,
use_v3_model=kwargs.get("use_v3_model", True),
point=kwargs.get("point",None)
)
elif pic_type == "icon":
point_list = do_pass_icon(pic_content, save_result)
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
total_time = time.time() - t
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
return JSONResponse(content=result)
raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")
# 3. Verify
elapsed = time.monotonic() - start_time
wait_time = max(0, 4.0 - elapsed)
time.sleep(wait_time)
response_str = crack.verify(point_list)
result = json.loads(response_str)
# 4. Post-process
passed = 'validate' in result.get('data', {})
if save_result:
save_image_for_train(pic_name, pic_type, passed)
else:
os.remove(os.path.join(validate_path,pic_name))
total_time = time.monotonic() - start_time
logging.info(
f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
f"结果: {result}"
)
return JSONResponse(content=result)
except Exception as e:
logging.error(f"服务错误: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
)
@app.get("/pass_nine")
def pass_nine(gt: str = Query(...),
challenge: str = Query(...),
point: Optional[str] = Query(default=None),
use_v3_model: bool = Query(default=True),
save_result: bool = Query(default=False)
):
return handle_pass_request(
gt, challenge, save_result,
use_v3_model=use_v3_model, point=point
)
@app.get("/pass_icon")
def pass_icon(gt: str = Query(...),
challenge: str = Query(...),
save_result: bool = Query(default=False)
):
return handle_pass_request(gt, challenge, save_result)
@app.get("/pass_uni")
def pass_uni(gt: str = Query(...),
challenge: str = Query(...),
save_result: bool = Query(default=False)
):
return handle_pass_request(gt, challenge, save_result)
@app.get("/pass_hutao")
def pass_hutao(gt: str = Query(...),
challenge: str = Query(...),
save_result: bool = Query(default=False)):
try:
# Call the shared handler and capture its response
response = handle_pass_request(gt, challenge, save_result)
# Extract the original status code and body
original_status_code = response.status_code
original_content = json.loads(response.body.decode("utf-8"))
if original_status_code == 200 and original_content.get("status", False) == "success" and "validate" in original_content.get("data", {}):
rebuild_content = {"code":0,"data":{"gt":gt,"challenge":challenge,"validate":original_content["data"]["validate"]}}
else:
rebuild_content = {"code":1,"data":{"gt":gt,"challenge":challenge,"validate":original_content}}
return JSONResponse(content=rebuild_content, status_code=original_status_code)
except Exception as e:
logging.error(f"修改路由错误: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={"error": "An internal server error occurred.", "detail": str(e)}
)
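# Shape of the rewrapped payload (illustrative values only):
#   success -> {"code": 0, "data": {"gt": "...", "challenge": "...", "validate": "<validate string>"}}
#   failure -> {"code": 1, "data": {"gt": "...", "challenge": "...", "validate": <full upstream response>}}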
if __name__ == "__main__":
from predict import predict_onnx,predict_onnx_pdl
import uvicorn
print(f"{' '*10}api: http://0.0.0.0:{port}/pass_nine{' '*10}")
print(f"{' '*10}api所需参数gt、challenge、point(可选){' '*10}")
uvicorn.run(app,host="0.0.0.0",port=port)
uvicorn.run(app,port=PORT)
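A minimal client sketch (assumptions: the service runs locally on PORT 9645, and the gt/challenge values are placeholders obtained from the target site; httpx is already imported by main.py):

    import httpx

    # Solve a nine-grid challenge via the /pass_nine route defined above.
    resp = httpx.get(
        "http://127.0.0.1:9645/pass_nine",
        params={"gt": "<gt>", "challenge": "<challenge>", "save_result": "false"},
        timeout=30.0,  # the handler deliberately waits up to ~4 s before verifying
    )
    print(resp.json())  # a successful solve carries data.validate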


@@ -1,19 +1,14 @@
import os
import numpy as np
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
from train import MyResNet18, data_transform
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image,bytes_to_pil,validate_path
import time
import cv2
from PIL import Image
from PIL import Image, ImageDraw
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
from train import MyResNet18, data_transform
import torch
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
@@ -74,10 +69,20 @@ def load_model(name='PP-HGNetV2-B4.onnx'):
model_path = os.path.join(current_dir, 'model', name)
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
print(f"加载{name}模型,耗时:{time.time() - start}")
def load_dfine_model(name='d-fine-n.onnx'):
# Load the ONNX model
global session_dfine
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', name)
session_dfine = ort.InferenceSession(model_path)
print(f"加载{name}模型,耗时:{time.time() - start}")
def predict_onnx(icon_image, bg_image, point = None):
import cv2
coordinates = [
[1, 1],
[1, 2],
@@ -145,7 +150,7 @@ def predict_onnx(icon_image, bg_image, point = None):
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"识别完成{answer},耗时: {time.time() - start}")
draw_points_on_image(bg_image, answer)
#draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_pdl(images_path):
@@ -194,15 +199,110 @@ def predict_onnx_pdl(images_path):
result = [np.argmax(one) for one in outputs]
target = result[-1]
answer = [coordinates[index] for index in range(9) if result[index] == target]
if len(answer) == 0:
all_sort =[np.argsort(one) for one in outputs]
answer = [coordinates[index] for index in range(9) if all_sort[index][1] == target]
print(f"识别完成{answer},耗时: {time.time() - start}")
if os.path.exists(os.path.join(images_path,"nine.jpg")):
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
draw_points_on_image(bg_image, answer)
with open(os.path.join(images_path,"nine.jpg"),'rb') as f:
bg_image = f.read()
draw_points_on_image(bg_image, answer)
return answer
def predict_onnx_dfine(image,draw_result=False):
input_nodes = session_dfine.get_inputs()
output_nodes = session_dfine.get_outputs()
image_input_name = input_nodes[0].name
size_input_name = input_nodes[1].name
output_names = [node.name for node in output_nodes]
if isinstance(image,bytes):
im_pil = bytes_to_pil(image)
else:
im_pil = Image.open(image).convert("RGB")
w, h = im_pil.size
orig_size_np = np.array([[w, h]], dtype=np.int64)
im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
im_data = np.array(im_resized, dtype=np.float32) / 255.0
im_data = im_data.transpose(2, 0, 1)
im_data = np.expand_dims(im_data, axis=0)
inputs = {
image_input_name: im_data,
size_input_name: orig_size_np
}
outputs = session_dfine.run(output_names, inputs)
output_map = {name: data for name, data in zip(output_names, outputs)}
labels = output_map['labels'][0]
boxes = output_map['boxes'][0]
scores = output_map['scores'][0]
colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
mask = scores > 0.4
filtered_labels = labels[mask]
filtered_boxes = boxes[mask]
filtered_scores = scores[mask]
rebuild_color = {}
unique_labels = list(set(filtered_labels))
for i, l_val in enumerate(unique_labels):
class_id = int(l_val)
if class_id not in rebuild_color:
rebuild_color[class_id] = colors[i % len(colors)]
result = {k: [] for k in unique_labels}
for i, box in enumerate(filtered_boxes):
label_val = filtered_labels[i]
class_id = int(label_val)
color = rebuild_color[class_id]
score = filtered_scores[i]
result[class_id].append({
'box': box,
'label_val': label_val,
'score': score
})
for class_id in result:
result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
sorted_result = {}
sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
for class_id in sorted_class_ids:
sorted_result[class_id] = result[class_id]
points = []
if draw_result:
draw = ImageDraw.Draw(im_pil)
for c1,class_id in enumerate(sorted_result):
items = sorted_result[class_id]
last_item = items[-1]
center_x = (last_item['box'][0] + last_item['box'][2]) / 2
center_y = (last_item['box'][1] + last_item['box'][3]) / 2
text_position_center = (center_x , center_y)
points.append(text_position_center)
if draw_result:
color = rebuild_color[class_id]
draw.point((center_x, center_y), fill=color)
text_center = f"{c1}"
draw.text(text_position_center, text_center, fill=color)
for c2,item in enumerate(items):
box = item['box']
score = item['score']
draw.rectangle(list(box), outline=color, width=1)
text = f"{class_id}_{c1}-{c2}: {score:.2f}"
text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
draw.text(text_position, text, fill=color)
if draw_result:
save_path = os.path.join(validate_path,"icon_result.jpg")
im_pil.save(save_path)
print(f"图片可视化结果保存在{save_path}")
print(f"图片顺序的中心点{points}")
return points
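# Usage sketch (an assumption about intended use; compare the __main__ block below): the
# function accepts raw image bytes or a file path and returns one (x, y) pixel center
# per detected icon class, e.g.:
#   with open("icon_captcha.jpg", "rb") as f:          # hypothetical file name
#       centers = predict_onnx_dfine(f.read(), draw_result=True)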
print(f"使用推理设备: {ort.get_device()}")
if int(os.environ.get("use_pdl",1)):
load_model()
if int(os.environ.get("use_dfine",1)):
load_dfine_model()
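# Both flags above are read once at import time; setting either environment variable to "0"
# before importing this module skips the corresponding model load, e.g.:
#   os.environ["use_pdl"] = "0"    # skip load_model()
#   os.environ["use_dfine"] = "0"  # skip load_dfine_model()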
if __name__ == "__main__":
# Use resnet18.onnx
# load_model("resnet18.onnx")
@@ -218,7 +318,5 @@ if __name__ == "__main__":
# predict_onnx(icon_image, bg_image)
# Use PP-HGNetV2-B4.onnx
load_model()
predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
else:
load_model()
#predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf')
predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)