def bytes_to_pil(image_bytes):
    """Decode raw image bytes into an RGB PIL image."""
    return Image.open(BytesIO(image_bytes)).convert('RGB')
import os
import json
import time
import httpx
import shutil
import random
import logging
import logging.config
from crack import Crack
from typing import Optional, Dict, Any
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from fastapi import FastAPI, Query, HTTPException
from predict import predict_onnx, predict_onnx_pdl, predict_onnx_dfine
from crop_image import crop_image_v3, save_path, save_fail_path, save_pass_path, validate_path

PORT = 9645

# --- Logging configuration ---
LOGGING_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            # BUG FIX: the formatter class lives in `uvicorn.logging`, not
            # `uvicorn._logging`; the old path made dictConfig() raise
            # "Unable to configure formatter 'default'" at startup.
            "()": "uvicorn.logging.DefaultFormatter",
            "fmt": "%(levelprefix)s %(asctime)s | %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        },
    },
    "handlers": {
        "default": {
            "formatter": "default",
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        # Root logger at INFO so application messages are visible.
        "": {"handlers": ["default"], "level": "INFO"},
        "uvicorn.error": {"level": "INFO"},
        "uvicorn.access": {"handlers": ["default"], "level": "INFO", "propagate": False},
    },
}
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)


def get_available_hosts() -> set[str]:
    """Return the machine's usable IPv4 addresses (always includes loopback)."""
    import socket
    hosts = {"127.0.0.1"}
    try:
        hostname = socket.gethostname()
        addr_info = socket.getaddrinfo(hostname, None, socket.AF_INET)
        hosts.update({info[4][0] for info in addr_info})
    except socket.gaierror:
        # Hostname did not resolve; fall back to the address chosen for an
        # outbound route (a UDP connect() sends no packets).
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(("8.8.8.8", 80))
                hosts.add(s.getsockname()[0])
        except OSError:
            pass
    return hosts


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Log reachable service URLs on startup and a marker on shutdown."""
    logger.info("=" * 50)
    logger.info("启动服务中...")
    # NOTE(review): FastAPI's `app.servers` holds OpenAPI server metadata
    # (plain dicts), not a running uvicorn Server, so `server.config` would
    # fail if it were ever non-empty; in practice this always falls back to
    # the defaults below — confirm intent.
    server = app.servers[0] if app.servers else None
    host = server.config.host if server else "0.0.0.0"
    port = server.config.port if server else PORT
    if host == "0.0.0.0":
        available_hosts = get_available_hosts()
        logger.info(f"服务地址(依需求选用,docker中使用宿主机host:{port}):")
        for h in sorted(list(available_hosts)):
            logger.info(f" - http://{h}:{port}")
    else:
        logger.info(f"服务地址: http://{host}:{port}")
    logger.info("可用服务路径如下:")
    for route in app.routes:
        logger.info(f" -{route.methods} {route.path}")
    logger.info("=" * 50)
    yield
    logger.info("=" * 50)
    logger.info("服务关闭")
    logger.info("=" * 50)


app = FastAPI(title="极验V3图标点选+九宫格", lifespan=lifespan)


def prepare(gt: str, challenge: str) -> tuple[Crack, bytes, str, str]:
    """Run the geetest handshake and download the captcha picture.

    Returns the Crack session plus the picture bytes, its file name and its
    type (dispatched on "nine"/"icon" by the caller).
    """
    # Consistency: use the module logger instead of the root `logging` calls.
    logger.info(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
    crack = Crack(gt, challenge)
    crack.gettype()
    crack.get_c_s()
    # Small randomized delay so the request timing looks human.
    time.sleep(random.uniform(0.4, 0.6))
    crack.ajax()
    pic_content, pic_name, pic_type = crack.get_pic()
    return crack, pic_content, pic_name, pic_type
def do_pass_icon(pic: Any, draw_result: bool) -> list[str]:
    """Solve an icon-click captcha.

    Runs the D-FINE detector and rescales each centre point from the 333px
    picture into the 0-10000 space expected by verify(), as "x_y" strings.
    """
    result_list = predict_onnx_dfine(pic, draw_result)
    # Consistency: module logger instead of a bare print().
    logger.info(f"icon detect result: {result_list}")
    return [f"{round(x / 333 * 10000)}_{round(y / 333 * 10000)}" for x, y in result_list]


def save_image_for_train(pic_name: str, pic_type: str, passed: bool) -> None:
    """Archive a solved captcha and its cropped pieces for later training.

    The original picture moves to ``save_path``; the per-type crops move to a
    pass/fail sub-directory named after the picture (extension stripped).
    """
    shutil.move(os.path.join(validate_path, pic_name), os.path.join(save_path, pic_name))
    base_dir = save_pass_path if passed else save_fail_path
    # splitext keeps dotted basenames intact (split('.')[0] would truncate them).
    path_2_save = os.path.join(base_dir, os.path.splitext(pic_name)[0])
    os.makedirs(path_2_save, exist_ok=True)
    prefix = {"nine": "cropped", "icon": "icon"}.get(pic_type)
    if prefix is None:
        return  # unknown type: nothing to collect (matches old behaviour)
    for pic in os.listdir(validate_path):
        if pic.startswith(prefix):
            shutil.move(os.path.join(validate_path, pic), os.path.join(path_2_save, pic))


def handle_pass_request(gt: str, challenge: str, save_result: bool, **kwargs) -> JSONResponse:
    """Core flow shared by every captcha endpoint.

    Fetches the challenge picture, dispatches to the solver matching its
    type, submits the answer after padding to the mandatory ~4s round trip,
    then archives or deletes the picture depending on ``save_result``.
    """
    start_time = time.monotonic()
    try:
        # 1. Fetch challenge data and the captcha picture.
        crack, pic_content, pic_name, pic_type = prepare(gt, challenge)

        # 2. Recognize.
        if pic_type == "nine":
            point_list = do_pass_nine(
                pic_content,
                use_v3_model=kwargs.get("use_v3_model", True),
                point=kwargs.get("point", None),
            )
        elif pic_type == "icon":
            point_list = do_pass_icon(pic_content, save_result)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown picture type: {pic_type}")

        # 3. Verify: geetest rejects answers submitted too quickly, so pad
        #    the elapsed time up to ~4 seconds before submitting.
        elapsed = time.monotonic() - start_time
        wait_time = max(0, 4.0 - elapsed)
        time.sleep(wait_time)
        result = json.loads(crack.verify(point_list))

        # 4. Post-process: archive for training or discard the picture.
        passed = 'validate' in result.get('data', {})
        if save_result:
            save_image_for_train(pic_name, pic_type, passed)
        else:
            os.remove(os.path.join(validate_path, pic_name))

        total_time = time.monotonic() - start_time
        logger.info(
            f"请求完成,耗时: {total_time:.2f}s (等待 {wait_time:.2f}s). "
            f"结果: {result}"
        )
        return JSONResponse(content=result)

    except HTTPException as he:
        # BUG FIX: deliberate client errors (the 400 raised above) used to be
        # swallowed by the broad handler below and reported as a 500.
        return JSONResponse(status_code=he.status_code, content={"error": he.detail})
    except Exception as e:
        logger.error(f"服务错误: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "An internal server error occurred.", "detail": str(e)}
        )


@app.get("/pass_nine")
def pass_nine(gt: str = Query(...),
              challenge: str = Query(...),
              point: Optional[str] = Query(default=None),
              use_v3_model: bool = Query(default=True),
              save_result: bool = Query(default=False)
              ):
    """Nine-grid captcha endpoint."""
    return handle_pass_request(
        gt, challenge, save_result,
        use_v3_model=use_v3_model, point=point
    )


@app.get("/pass_icon")
def pass_icon(gt: str = Query(...),
              challenge: str = Query(...),
              save_result: bool = Query(default=False)
              ):
    """Icon-click captcha endpoint."""
    return handle_pass_request(gt, challenge, save_result)


@app.get("/pass_uni")
def pass_uni(gt: str = Query(...),
             challenge: str = Query(...),
             save_result: bool = Query(default=False)
             ):
    """Universal endpoint: the picture type reported by geetest decides the solver."""
    return handle_pass_request(gt, challenge, save_result)


@app.get("/pass_hutao")
def pass_hutao(gt: str = Query(...),
               challenge: str = Query(...),
               save_result: bool = Query(default=False)):
    """Hutao-compatible wrapper: reshapes the response into {code, data}."""
    try:
        response = handle_pass_request(gt, challenge, save_result)
        original_status_code = response.status_code
        original_content = json.loads(response.body.decode("utf-8"))
        if (original_status_code == 200
                and original_content.get("status") == "success"
                and "validate" in original_content.get("data", {})):
            rebuild_content = {"code": 0,
                               "data": {"gt": gt, "challenge": challenge,
                                        "validate": original_content["data"]["validate"]}}
        else:
            rebuild_content = {"code": 1,
                               "data": {"gt": gt, "challenge": challenge,
                                        "validate": original_content}}
        return JSONResponse(content=rebuild_content, status_code=original_status_code)

    except Exception as e:
        logger.error(f"修改路由错误: {e}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": "An internal server error occurred.", "detail": str(e)}
        )
def load_dfine_model(name='d-fine-n.onnx'):
    """Load the D-FINE detection ONNX model into the module-global session."""
    global session_dfine
    start = time.time()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    session_dfine = ort.InferenceSession(os.path.join(base_dir, 'model', name))
    print(f"加载{name}模型,耗时:{time.time() - start}")
def predict_onnx_dfine(image, draw_result=False, score_thr=0.4):
    """Run the D-FINE ONNX detector on an icon-click captcha image.

    Parameters
    ----------
    image : bytes | str | os.PathLike
        Raw image bytes, or a path to an image file.
    draw_result : bool
        When True, draw boxes/centres on the image and save it under
        ``validate_path`` as ``icon_result.jpg``.
    score_thr : float
        Detection confidence threshold (generalizes the previously
        hard-coded 0.4; default preserves old behaviour).

    Returns
    -------
    list[tuple[float, float]]
        One centre point per detected class, classes ordered left-to-right.
    """
    input_nodes = session_dfine.get_inputs()
    output_nodes = session_dfine.get_outputs()
    image_input_name = input_nodes[0].name
    size_input_name = input_nodes[1].name
    output_names = [node.name for node in output_nodes]

    if isinstance(image, bytes):
        im_pil = bytes_to_pil(image)
    else:
        # BUG FIX: this branch referenced the undefined name `image_path`,
        # raising NameError for any non-bytes input.
        im_pil = Image.open(image).convert("RGB")
    w, h = im_pil.size
    orig_size_np = np.array([[w, h]], dtype=np.int64)

    # Model expects a 1x3x320x320 float tensor scaled to [0, 1].
    im_resized = im_pil.resize((320, 320), Image.Resampling.BILINEAR)
    im_data = np.array(im_resized, dtype=np.float32) / 255.0
    im_data = np.expand_dims(im_data.transpose(2, 0, 1), axis=0)

    inputs = {
        image_input_name: im_data,
        size_input_name: orig_size_np
    }
    outputs = session_dfine.run(output_names, inputs)
    output_map = dict(zip(output_names, outputs))
    labels = output_map['labels'][0]
    boxes = output_map['boxes'][0]
    scores = output_map['scores'][0]

    colors = ["red", "blue", "green", "yellow", "white", "purple", "orange"]
    mask = scores > score_thr
    filtered_labels = labels[mask]
    filtered_boxes = boxes[mask]
    filtered_scores = scores[mask]

    # One colour per class id (keys normalized to plain int).
    rebuild_color = {}
    unique_labels = list(set(filtered_labels))
    for i, l_val in enumerate(unique_labels):
        class_id = int(l_val)
        if class_id not in rebuild_color:
            rebuild_color[class_id] = colors[i % len(colors)]

    # Group detections per class.
    result = {int(k): [] for k in unique_labels}
    for i, box in enumerate(filtered_boxes):
        label_val = filtered_labels[i]
        class_id = int(label_val)
        result[class_id].append({
            'box': box,
            'label_val': label_val,
            'score': filtered_scores[i],
        })

    # Within each class sort by bottom edge (largest y2 first), so the last
    # item is the top-most box; order classes left-to-right by x1 of the
    # bottom-most box.
    for class_id in result:
        result[class_id].sort(key=lambda item: item['box'][3], reverse=True)
    sorted_result = {}
    sorted_class_ids = sorted(result.keys(), key=lambda cid: result[cid][0]['box'][0])
    for class_id in sorted_class_ids:
        sorted_result[class_id] = result[class_id]

    points = []
    if draw_result:
        draw = ImageDraw.Draw(im_pil)
    for c1, class_id in enumerate(sorted_result):
        items = sorted_result[class_id]
        last_item = items[-1]
        center_x = (last_item['box'][0] + last_item['box'][2]) / 2
        center_y = (last_item['box'][1] + last_item['box'][3]) / 2
        text_position_center = (center_x, center_y)
        points.append(text_position_center)
        if draw_result:
            color = rebuild_color[class_id]
            draw.point((center_x, center_y), fill=color)
            draw.text(text_position_center, f"{c1}", fill=color)
            for c2, item in enumerate(items):
                box = item['box']
                score = item['score']
                draw.rectangle(list(box), outline=color, width=1)
                text = f"{class_id}_{c1}-{c2}: {score:.2f}"
                text_position = (box[0] + 2, box[1] - 12 if box[1] > 12 else box[1] + 2)
                draw.text(text_position, text, fill=color)
    if draw_result:
        save_path = os.path.join(validate_path, "icon_result.jpg")
        im_pil.save(save_path)
        print(f"图片可视化结果保存在{save_path}")
    print(f"图片顺序的中心点{points}")
    return points
使用resnet18.onnx # load_model("resnet18.onnx") @@ -218,7 +318,5 @@ if __name__ == "__main__": # predict_onnx(icon_image, bg_image) # 使用PP-HGNetV2-B4.onnx - load_model() - predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf') -else: - load_model() \ No newline at end of file + #predict_onnx_pdl(r'img_saved\img_fail\7fe559a85bac4c03bc6ea7b2e85325bf') + predict_onnx_dfine(r"n:\爬点选\dataset\3f98ff0c91dd4882a8a24d451283ad96.jpg",True)