This commit is contained in:
luguoyixiazi
2024-11-01 01:18:01 +08:00
committed by GitHub
commit bfb1ace0ee
91 changed files with 3397 additions and 0 deletions

63
README.md Normal file
View File

@@ -0,0 +1,63 @@
# 九宫格测试代码
## **本项目仅供学习交流使用,请勿用于商业用途,否则后果自负。**
## **本项目仅供学习交流使用,请勿用于商业用途,否则后果自负。**
## **本项目仅供学习交流使用,请勿用于商业用途,否则后果自负。**
## 参考项目
模型及V4数据集https://github.com/taisuii/ClassificationCaptchaOcr
apihttps://github.com/ravizhan/geetest-v3-click-crack
## 运行步骤
### 1.安装依赖
```
pip install -r requirements.txt
```
### 2.自行准备数据集V3和V4有区别
- 数据集详情参考上面标注的项目但是上面项目是V4数据集V3没有demo自行发挥吧用V4正确率有点感人或许可以试试别的模型看看能不能泛化
- 如果要切V3的图用crop_image.py的crop_image_v3切V4则使用crop_image自行编写切图脚本
### 3.训练模型
训练运行 `python train.py`
### 4.模型转换为onnx
运行 `python convert.py`自行进去修改需要转换的模型一般是选loss小的
### 5.启动fastapi服务
运行 `python main.py`
### 6.api调用
python调用如
```python
import httpx
res = httpx.get("http://127.0.0.1:9645/pass_nine",params={'gt':gt,'challenge':challenge},timeout=10)
datas = res.json()['data']
if datas['result'] == 'success':
return datas['validate']
```

0
__init__.py Normal file
View File

17
convert.py Normal file
View File

@@ -0,0 +1,17 @@
from train import MyResNet18
import torch
def convert():
# 加载 PyTorch 模型
model_path = "model/resnet18_39_0.01445627337038193.pth"
model = MyResNet18(num_classes=91)
model.load_state_dict(torch.load(model_path))
model.eval()
# 生成一个示例输入
dummy_input = torch.randn(10, 3, 224, 224)
# 将模型转换为 ONNX 格式
torch.onnx.export(model, dummy_input, "model/resnet18.onnx", verbose=True)
if __name__ == '__main__':
convert()

486
crack.py Normal file
View File

@@ -0,0 +1,486 @@
import hashlib
import json
import math
import random
import time
import httpx
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from crop_image import validate_path
from os import path as PATH
class Crack:
def __init__(self, gt=None, challenge=None):
self.pic_path = None
self.s = None
self.c = None
self.session = httpx.Client(http2=True)
self.session.headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
}
# self.session.verify = False
self.gt = gt
self.challenge = challenge
self.aeskey = ''.join(f'{int((1 + random.random()) * 65536):04x}'[1:] for _ in range(4))
public_key = '''-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDB45NNFhRGWzMFPn9I7k7IexS5
XviJR3E9Je7L/350x5d9AtwdlFH3ndXRwQwprLaptNb7fQoCebZxnhdyVl8Jr2J3
FZGSIa75GJnK4IwNaG10iyCjYDviMYymvCtZcGWSqSGdC/Bcn2UCOiHSMwgHJSrg
Bm1Zzu+l8nSOqAurgQIDAQAB
-----END PUBLIC KEY-----'''
self.public_key = serialization.load_pem_public_key(public_key.encode())
self.enc_key = self.public_key.encrypt(self.aeskey.encode(), PKCS1v15()).hex()
with open("mousepath.json", "r") as f:
self.mouse_path = json.loads(f.read())
def get_type(self) -> dict:
url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
res = self.session.get(url)
data = json.loads(res.text[1:-1])["data"]
return data
@staticmethod
def encode(input_bytes: list):
def get_char_from_index(index):
char_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789()"
return char_table[index] if 0 <= index < len(char_table) else "."
def transform_value(value, bit_mask):
result = 0
for r in range(23, -1, -1):
if (bit_mask >> r) & 1:
result = (result << 1) + ((value >> r) & 1)
return result
encoded_string = ""
padding = ""
input_length = len(input_bytes)
for i in range(0, input_length, 3):
chunk_length = min(3, input_length - i)
chunk = input_bytes[i:i + chunk_length]
if chunk_length == 3:
value = (chunk[0] << 16) + (chunk[1] << 8) + chunk[2]
encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
transform_value(value, 9483264)) + get_char_from_index(
transform_value(value, 19220)) + get_char_from_index(transform_value(value, 235))
elif chunk_length == 2:
value = (chunk[0] << 16) + (chunk[1] << 8)
encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
transform_value(value, 9483264)) + get_char_from_index(transform_value(value, 19220))
padding = "."
elif chunk_length == 1:
value = chunk[0] << 16
encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
transform_value(value, 9483264))
padding = ".."
return encoded_string + padding
@staticmethod
def MD5(text: str):
return hashlib.md5(text.encode()).hexdigest()
@staticmethod
def encode_mouse_path(path: list, c: list, s: str):
def preprocess(path: list):
def BFIQ(e):
t = 32767
if not isinstance(e, int):
return e
else:
if t < e:
e = t
elif e < -t:
e = -t
return round(e)
def BGAB(e):
t = ''
n = 0
len(e or [])
while n < len(e) and not t:
if e[n]:
t = e[n][4]
n += 1
if not t:
return e
r = ''
i = ['mouse', 'touch', 'pointer', 'MSPointer']
for s in range(len(i)):
if t.startswith(i[s]):
r = i[s]
_ = list(e)
for a in range(len(_) - 1, -1, -1):
c = _[a]
l = c[0]
if l in ['move', 'down', 'up']:
value = c[4] or ''
if not value.startswith(r):
_.pop(a)
return _
t = 0
n = 0
r = []
s = 0
if len(path) <= 0:
return []
o = None
_ = None
a = BGAB(path)
c = len(a)
for l in range(0 if c < 300 else c - 300, c):
u = a[l]
h = u[0]
if h in ['down', 'move', 'up', 'scroll']:
if not o:
o = u
_ = u
r.append([h, [u[1] - t, u[2] - n], BFIQ(u[3] - s if s else s)])
t = u[1]
n = u[2]
s = u[3]
elif h in ['blur', 'focus', 'unload']:
r.append([h, BFIQ(u[1] - s if s else s)])
s = u[1]
return r
def process(prepared_path: list):
h = {
'move': 0,
'down': 1,
'up': 2,
'scroll': 3,
'focus': 4,
'blur': 5,
'unload': 6,
'unknown': 7
}
def p(e, t):
n = bin(e)[2:]
r = ''
i = len(n) + 1
while i <= t:
i += 1
r += '0'
return r + n
def d(e):
t = []
n = len(e)
r = 0
while r < n:
i = e[r]
s = 0
while True:
if s >= 16:
break
o = r + s + 1
if o >= n:
break
if e[o] != i:
break
s += 1
r += 1 + s
_ = h[i]
if s != 0:
t.append(_ | 8)
t.append(s - 1)
else:
t.append(_)
a = p(n | 32768, 16)
c = ''
for l in range(len(t)):
c += p(t[l], 4)
return a + c
def g(e, tt):
def temp1(e1):
n = len(e)
r = 0
i = []
while r < n:
s = 1
o = e[r]
_ = abs(o)
while True:
if n <= r + s:
break
if e[r + s] != o:
break
if (_ >= 127) or (s >= 127):
break
s += 1
if s > 1:
i.append((49152 if o < 0 else 32768) | s << 7 | _)
else:
i.append(o)
r += s
return i
e = temp1(e)
r = []
i = []
def n(e, t):
return 0 if e == 0 else math.log(e) / math.log(t)
for temp in e:
t = math.ceil(n(abs(temp) + 1, 16))
if t == 0:
t = 1
r.append(p(t - 1, 2))
i.append(p(abs(temp), t * 4))
s = ''.join(r)
o = ''.join(i)
def temp2(t):
return t != 0 and t >> 15 != 1
def temp3(e1):
n = []
def temp(e2):
if temp2(e2):
n.append(e2)
for r in range(len(e1)):
temp(e1[r])
return n
def temp4(t):
if t < 0:
return '1'
else:
return '0'
if tt:
n = []
e1 = temp3(e)
for r in range(len(e1)):
n.append(temp4(e1[r]))
n = ''.join(n)
else:
n = ''
return p(len(e) | 32768, 16) + s + o + n
def u(e):
t = ''
n = len(e) // 6
for r in range(n):
t += '()*,-./0123456789:?@ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~'[
int(e[6 * r: 6 * (r + 1)], 2)]
return t
t = []
n = []
r = []
i = []
for a in range(len(prepared_path)):
_ = prepared_path[a]
a = len(_)
t.append(_[0])
n.append(_[1] if a == 2 else _[2])
if a == 3:
r.append(_[1][0])
i.append(_[1][1])
c = d(t) + g(n, False) + g(r, True) + g(i, True)
l = len(c)
if l % 6 != 0:
c += p(0, 6 - l % 6)
return u(c)
def postprocess(e, t, n):
i = 0
s = e
o = t[0]
_ = t[2]
a = t[4]
while True:
r = n[i:i + 2]
if not r:
break
i += 2
c = int(r, 16)
l = chr(c)
u = (o * c * c + _ * c + a) % len(e)
s = s[:u] + l + s[u:]
return s
return postprocess(process(preprocess(path)), c, s)
def aes_encrypt(self, content: str):
cipher = Cipher(algorithms.AES(self.aeskey.encode()), modes.CBC(b"0000000000000000"))
encryptor = cipher.encryptor()
padder = padding.PKCS7(128).padder()
padded_data = padder.update(content.encode())
padded_data += padder.finalize()
ct = encryptor.update(padded_data) + encryptor.finalize()
return ct
def get_c_s(self):
o = {
"gt": self.gt,
"challenge": self.challenge,
"offline": False,
"new_captcha": True,
"product": "embed",
"width": "300px",
"https": True,
"protocol": "https://",
}
o.update(self.get_type())
o.update({
"cc": 16,
"ww": True,
"i": "-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1"
})
o = json.dumps(o, separators=(',', ':'))
ct = self.aes_encrypt(o)
s = []
for byte in ct:
s.append(byte)
i = self.encode(s)
r = self.enc_key
w = i + r
params = {
"gt": self.gt,
"challenge": self.challenge,
"lang": "zh-cn",
"pt": 0,
"client_type": "web",
"callback": "geetest_" + str(int(round(time.time() * 1000))),
"w": w
}
resp = self.session.get("https://api.geetest.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
self.c = data["c"]
self.s = data["s"]
return data["c"], data["s"]
def gettype(self):
url = f"https://api.geetest.com/gettype.php?gt={self.gt}&callback=geetest_{str(int(round(time.time() * 1000)))}"
return self.session.get(url).text
def ajax(self):
def transform(e, t, n):
if not t or not n:
return e
o = 0
i = list(e)
s = t[0]
a = t[2]
b = t[4]
while o < len(n):
r = n[o:o + 2]
o += 2
c = int(r, 16)
l = chr(c)
u = (s * c * c + a * c + b) % len(i)
i.insert(u, l)
return ''.join(i)
mouse_path = [
["move", 385, 313, 1724572150164, "pointermove"],
["move", 385, 315, 1724572150166, "pointermove"],
["move", 386, 315, 1724572150174, "pointermove"],
["move", 387, 315, 1724572150182, "pointermove"],
["move", 387, 316, 1724572150188, "pointermove"],
["move", 388, 316, 1724572150204, "pointermove"],
["move", 388, 317, 1724572150218, "pointermove"],
["down", 388, 317, 1724572150586, "pointerdown"],
["focus", 1724572150587],
["up", 388, 317, 1724572150632, "pointerup"]
]
tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
rp = self.MD5(self.gt + self.challenge + self.s)
temp1 = '''"lang":"zh-cn","type":"fullpage","tt":"%s","light":"DIV_0","s":"c7c3e21112fe4f741921cb3e4ff9f7cb","h":"321f9af1e098233dbd03f250fd2b5e21","hh":"39bd9cad9e425c3a8f51610fd506e3b3","hi":"09eb21b3ae9542a9bc1e8b63b3d9a467","vip_order":-1,"ct":-1,"ep":{"v":"9.1.9-dbjg5z","te":false,"me":true,"ven":"Google Inc. (Intel)","ren":"ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)","fp":["scroll",0,1602,1724571628498,null],"lp":["up",386,217,1724571629854,"pointerup"],"em":{"ph":0,"cp":0,"ek":"11","wd":1,"nt":0,"si":0,"sc":0},"tm":{"a":1724571567311,"b":1724571567549,"c":1724571567562,"d":0,"e":0,"f":1724571567312,"g":1724571567312,"h":1724571567312,"i":1724571567317,"j":1724571567423,"k":1724571567330,"l":1724571567423,"m":1724571567545,"n":1724571567547,"o":1724571567569,"p":1724571568259,"q":1724571568259,"r":1724571568261,"s":1724571570378,"t":1724571570378,"u":1724571570380},"dnf":"dnf","by":0},"passtime":1600,"rp":"%s",''' % (
tt, rp)
r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
ct = self.aes_encrypt(r)
s = [byte for byte in ct]
w = self.encode(s)
params = {
"gt": self.gt,
"challenge": self.challenge,
"lang": "zh-cn",
"pt": 0,
"client_type": "web",
"callback": "geetest_" + str(int(round(time.time() * 1000))),
"w": w
}
resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
return json.loads(resp[22:-1])["data"]
def get_pic(self):
params = {
"is_next": "true",
"type": "click",
"gt": self.gt,
"challenge": self.challenge,
"lang": "zh-cn",
"https": "true",
"protocol": "https://",
"offline": "false",
"product": "float",
"api_server": "api.geevisit.com",
"isPC": True,
"autoReset": True,
"width": "100%",
"callback": "geetest_" + str(int(round(time.time() * 1000))),
}
resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
data = json.loads(resp[22:-1])["data"]
self.pic_path = data["pic"]
pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
pic_data = self.session.get(pic_url).content
pic_name = data["pic"].split("/")[-1]
with open(PATH.join(validate_path,pic_name),'wb+') as f:
f.write(pic_data)
return pic_data,pic_name
def verify(self, points: list):
u = self.enc_key
o = {
"lang": "zh-cn",
"passtime": 1600,
"a": ",".join(points),
"pic": self.pic_path,
"tt": self.encode_mouse_path(self.mouse_path, self.c, self.s),
"ep": {
"ca": [{"x": 524, "y": 209, "t": 0, "dt": 1819}, {"x": 558, "y": 299, "t": 0, "dt": 428},
{"x": 563, "y": 95, "t": 0, "dt": 952}, {"x": 670, "y": 407, "t": 3, "dt": 892}],
"v": '3.1.0',
"$_FB": False,
"me": True,
"tm": {"a": 1724585496403, "b": 1724585496605, "c": 1724585496613, "d": 0, "e": 0, "f": 1724585496404,
"g": 1724585496404, "h": 1724585496404, "i": 1724585496404, "j": 1724585496404, "k": 0,
"l": 1724585496413, "m": 1724585496601, "n": 1724585496603, "o": 1724585496618,
"p": 1724585496749, "q": 1724585496749, "r": 1724585496751, "s": 1724585498068,
"t": 1724585498068, "u": 1724585498069}
},
"h9s9": "1816378497",
}
o["rp"] = self.MD5(self.gt + self.challenge + str(o["passtime"]))
o = json.dumps(o, separators=(',', ':'))
ct = self.aes_encrypt(o)
s = []
for byte in ct:
s.append(byte)
p = self.encode(s)
w = p + u
params = {
"gt": self.gt,
"challenge": self.challenge,
"lang": "zh-cn",
"pt": 0,
"client_type": "web",
"w": w
}
resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
return resp[1:-1]

139
crop_image.py Normal file
View File

@@ -0,0 +1,139 @@
from PIL import Image, ImageFont, ImageDraw, ImageOps
from io import BytesIO
import cv2
import numpy as np
import os
current_path = os.getcwd()
validate_path = os.path.join(current_path,'img_2_val')#要验证的图片暂存
save_path = os.path.join(current_path,'img_saved')#存放历史图片,留作做数据集以待标记
save_pass_path = os.path.join(save_path,'img_pass')#校验失败的图片,可能是轨迹有误,不一定是分类错误
save_fail_path = os.path.join(save_path,'img_fail')#校验成功的图片,但有可能有个别分类错误
os.makedirs(validate_path,exist_ok=True)
os.makedirs(save_path,exist_ok=True)
os.makedirs(save_pass_path,exist_ok=True)
os.makedirs(save_fail_path,exist_ok=True)
def draw_points_on_image(bg_image, answer):
# 将背景图片转换为OpenCV格式
bg_image_cv = cv2.imdecode(np.frombuffer(bg_image, dtype=np.uint8), cv2.IMREAD_COLOR)
# 定义九宫格的大小和偏移量
grid_width = 100
grid_height = 86
offset_x = 45
offset_y = 38
for i, (row, col) in enumerate(answer):
x = offset_x + (col-1) * grid_width
y = offset_y + (row-1) * grid_height
cv2.circle(bg_image_cv, (x, y), 10, (0, 0, 255), -1)
cv2.imwrite('./img_2_val/predicted.jpg', bg_image_cv)#推理结果
def convert_png_to_jpg(png_bytes: bytes) -> bytes:
# 将传入的 bytes 转换为图像对象
png_image = Image.open(BytesIO(png_bytes))
# 创建一个 BytesIO 对象,用于存储输出的 JPG 数据
output_bytes = BytesIO()
# 检查图像是否具有透明度通道 (RGBA)
if png_image.mode == 'RGBA':
# 创建白色背景
white_bg = Image.new("RGB", png_image.size, (255, 255, 255))
# 将 PNG 图像粘贴到白色背景上,透明部分用白色填充
white_bg.paste(png_image, (0, 0), png_image)
jpg_image = white_bg
else:
# 如果图像没有透明度,直接转换为 RGB 模式
jpg_image = png_image.convert("RGB")
# 将转换后的图像保存为 JPG 格式到 BytesIO 对象
jpg_image.save(output_bytes, format="JPEG")
# 返回保存后的 JPG 图像的 bytes
return output_bytes.getvalue()
def crop_image(image_bytes, coordinates):
img = Image.open(BytesIO(image_bytes))
width, height = img.size
grid_width = width // 3
grid_height = height // 3
cropped_images = []
for coord in coordinates:
y, x = coord
left = (x - 1) * grid_width
upper = (y - 1) * grid_height
right = left + grid_width
lower = upper + grid_height
box = (left, upper, right, lower)
cropped_img = img.crop(box)
cropped_images.append(cropped_img)
return cropped_images
def crop_image_v3(image_bytes):
coordinates = [
[[0, 0], [112, 112]],
[[116, 0], [228, 112]],
[[232, 0], [344, 112]],#第一行
[[0, 116], [112, 228]],
[[116, 116], [228, 228]],
[[232, 116], [344, 228]],#第二行
[[0, 232], [112, 344]],
[[116, 232], [228, 344]],
[[232, 232], [344, 344]],#第三行
[[2, 344], [42, 384]] #要验证的
]
image = Image.open(BytesIO(image_bytes))
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
imageNew = Image.new('RGB', (300,261),(0,0,0))
images = []
for i, (start_point, end_point) in enumerate(coordinates):
x1, y1 = start_point
x2, y2 = end_point
# 切割图像
cropped_image = image.crop((x1, y1, x2, y2))
images.append(cropped_image)
# 保存切割后的图像
output_path = os.path.join(validate_path,f'cropped_{i}.jpg')
cropped_image.save(output_path)
for i in range(3):
imageNew.paste(images[i].resize((100,86)), (i*100, 0, (i+1)*100, 86))
imageNew.paste(images[i+3].resize((100,86)), (i*100, 86, (i+1)*100, 172))
imageNew.paste(images[i+6].resize((100,86)), (i*100,172, (i+1)*100, 258))
imageNew.save(os.path.join(validate_path,f'nine.jpg') )
if __name__ == "__main__":
# v4测试代码
# os.makedirs(os.path.join(current_path,'image_test'),exist_ok=True)
# # 切割顺序,这里是从左到右,从上到下[x,y]
# coordinates = [
# [1, 1],
# [1, 2],
# [1, 3],
# [2, 1],
# [2, 2],
# [2, 3],
# [3, 1],
# [3, 2],
# [3, 3],
# ]
# with open("./image_test/bg.jpg", "rb") as rb:
# bg_img = rb.read()
# cropped_images = crop_image(bg_img, coordinates)
# # 一个个保存下来
# for j, img_crop in enumerate(cropped_images):
# img_crop.save(f"./image_test/bg{j}.jpg")
# # 图标格式转换
# with open("./image_test/icon.png", "rb") as rb:
# icon_img = rb.read()
# icon_img_jpg = convert_png_to_jpg(icon_img)
# with open("./image_test/icon.jpg", "wb") as wb:
# wb.write(icon_img_jpg)
# V3测试代码
pic = "./img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg"
print("推理图片为:",pic)
with open(pic, "rb") as f:
img = f.read()
crop_image_v3(img)

BIN
img_2_val/nine.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 17 KiB

BIN
img_2_val/predicted.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1023 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 933 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 984 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 949 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 KiB

66
main.py Normal file
View File

@@ -0,0 +1,66 @@
import json
import time
import random
from crack import Crack
from crop_image import crop_image_v3,save_path,save_fail_path,save_pass_path,validate_path
import httpx
from fastapi import FastAPI,Query
from fastapi.responses import JSONResponse
import shutil
import os
port = 9645
# api
app = FastAPI()
@app.get("/pass_nine")
def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Query(default=None)):
print(f"开始获取:\ngt:{gt}\nchallenge:{challenge}")
t = time.time()
crack = Crack(gt, challenge)
crack.gettype()
crack.get_c_s()
time.sleep(random.uniform(0.4,0.6))
crack.ajax()
pic_content,pic_name = crack.get_pic()
crop_image_v3(pic_content)
with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
icon_image = rb.read()
with open(f"{validate_path}/nine.jpg", "rb") as rb:
bg_image = rb.read()
result_list = predict_onnx(icon_image, bg_image, point)
point_list = [f"{col}_{row}" for row, col in result_list]
wait_time = 4.0 - (time.time() - t)
time.sleep(wait_time)
result = json.loads(crack.verify(point_list))
shutil.move(os.path.join(validate_path,pic_name),os.path.join(save_path,pic_name))
if 'validate' in result['data']:
path_2_save = os.path.join(save_pass_path,pic_name.split('.')[0])
else:
path_2_save = os.path.join(save_fail_path,pic_name.split('.')[0])
os.makedirs(path_2_save,exist_ok=True)
for pic in os.listdir(validate_path):
if pic.startswith('cropped'):
shutil.move(os.path.join(validate_path,pic),os.path.join(path_2_save,pic))
total_time = time.time() - t
print(f"总计耗时(含等待{wait_time}s): {total_time}\n{result}")
return JSONResponse(content=result)
if __name__ == "__main__":
from predict import predict_onnx
import uvicorn
print(f"{' '*10}api: http://127.0.0.1:{port}/pass_nine{' '*10}")
print(f"{' '*10}api所需参数gt、challenge、point(可选){' '*10}")
uvicorn.run(app,port=port)

2337
mousepath.json Normal file

File diff suppressed because it is too large Load Diff

159
predict.py Normal file
View File

@@ -0,0 +1,159 @@
import os
import numpy as np
from train import MyResNet18, data_transform
from crop_image import crop_image, convert_png_to_jpg,draw_points_on_image
import torch
import time
import cv2
from PIL import Image
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
coordinates = [
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3],
]
target_images = []
target_images.append(data_transform(Image.open(BytesIO(icon_image))))
bg_images = crop_image(bg_image, coordinates)
for bg_image in bg_images:
target_images.append(data_transform(bg_image))
start = time.time()
model = MyResNet18(num_classes=91) # 这里的类别数要与训练时一致
model.load_state_dict(torch.load(model_path))
model.eval()
print("加载模型,耗时:", time.time() - start)
start = time.time()
target_images = torch.stack(target_images, dim=0)
target_outputs = model(target_images)
scores = []
for i, out_put in enumerate(target_outputs):
if i == 0:
# 增加维度,以便于计算
target_output = out_put.unsqueeze(0)
else:
similarity = torch.nn.functional.cosine_similarity(
target_output, out_put.unsqueeze(0)
)
scores.append(similarity.cpu().item())
# 从左到右,从上到下,依次为每张图片的置信度
print(scores)
# 对数组进行排序,保持下标
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# 提取最大三个数及其下标
largest_three = sorted_arr[:3]
print(largest_three)
print("识别完成,耗时:", time.time() - start)
# 加载onnx模型
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
def predict_onnx(icon_image, bg_image, point = None):
coordinates = [
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 1],
[3, 2],
[3, 3],
]
def cosine_similarity(vec1, vec2):
# 将输入转换为 NumPy 数组
vec1 = np.array(vec1)
vec2 = np.array(vec2)
# 计算点积
dot_product = np.dot(vec1, vec2)
# 计算向量的范数
norm_vec1 = np.linalg.norm(vec1)
norm_vec2 = np.linalg.norm(vec2)
# 计算余弦相似度
similarity = dot_product / (norm_vec1 * norm_vec2)
return similarity
def data_transforms(image):
image = image.resize((224, 224))
image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
image_array = np.array(image)
image_array = image_array.astype(np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image_array = (image_array - mean) / std
image_array = np.transpose(image_array, (2, 0, 1))
# image_array = np.expand_dims(image_array, axis=0)
return image_array
target_images = []
target_images.append(data_transforms(Image.open(BytesIO(icon_image))))
bg_images = crop_image(bg_image, coordinates)
for one in bg_images:
target_images.append(data_transforms(one))
start = time.time()
outputs = session.run(None, {input_name: target_images})[0]
scores = []
for i, out_put in enumerate(outputs):
if i == 0:
target_output = out_put
else:
similarity = cosine_similarity(target_output, out_put)
scores.append(similarity)
# 从左到右,从上到下,依次为每张图片的置信度
# print(scores)
# 对数组进行排序,保持下标
indexed_arr = list(enumerate(scores))
sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
# 提取最大三个数及其下标
if point == None:
largest_three = sorted_arr[:3]
answer = [coordinates[i[0]] for i in largest_three]
# 基于分数判断
else:
answer = [one[0] for one in sorted_arr if one[1] > point]
print(f"识别完成{answer},耗时: {time.time() - start}")
draw_points_on_image(bg_image, answer)
return answer
if __name__ == "__main__":
icon_path = "img_2_val/cropped_9.jpg"
bg_path = "img_2_val/nine.jpg"
with open(icon_path, "rb") as rb:
if icon_path.endswith('.png'):
icon_image = convert_png_to_jpg(rb.read())
else:
icon_image = rb.read()
with open(bg_path, "rb") as rb:
bg_image = rb.read()
predict_onnx(icon_image, bg_image)

11
requirements.txt Normal file
View File

@@ -0,0 +1,11 @@
httpx
cryptography
onnxruntime
opencv-python
numpy
torch
torchvision
Pillow
matplotlib
tqdm
shutil

119
train.py Normal file
View File

@@ -0,0 +1,119 @@
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os
os.makedirs(os.path.join(os.getcwd(),'model'),exist_ok=True)
# 定义数据转换
data_transform = transforms.Compose(
[
transforms.Resize((224, 224)), # 调整图像大小
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(
(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
), # 标准化图像
]
)
# 定义数据集
class CustomDataset:
def __init__(self, data_dir):
self.dataset = ImageFolder(root=data_dir, transform=data_transform)
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
image, label = self.dataset[idx]
return image, label
class MyResNet18(torch.nn.Module):
def __init__(self, num_classes):
super(MyResNet18, self).__init__()
self.resnet = torchvision.models.resnet18(pretrained=True)
self.resnet.fc = nn.Linear(512, num_classes) # 修改这里的输入大小为512
def forward(self, x):
return self.resnet(x)
def train(epoch):
print("judge the cuda: " + str(torch.version.cuda))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("this train use devices: " + str(device))
data_dir = "dataset"
# 自定义数据集实例
custom_dataset = CustomDataset(data_dir)
# 数据加载器
batch_size = 64
data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)
# 初始化模型 num_classes就是目录下的子文件夹数目每个子文件夹对应一个分类模型输出的向量长度也是这个长度
model = MyResNet18(num_classes=91)
model.to(device)
# 损失函数
criterion = torch.nn.CrossEntropyLoss()
# 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
epoch_losses = []
# 训练模型
for i in range(epoch):
losses = []
# 迭代器进度条
data_loader_tqdm = tqdm(data_loader)
epoch_loss = 0
for inputs, labels in data_loader_tqdm:
# 将输入数据和标签传输到指定的计算设备(如 GPU 或 CPU
inputs, labels = inputs.to(device), labels.to(device)
# 梯度更新之前将所有模型参数的梯度置为零,防止梯度累积
optimizer.zero_grad()
# 前向传播:将输入数据传入模型,计算输出
outputs = model(inputs)
# 根据模型的输出和实际标签计算损失值
loss = criterion(outputs, labels)
# 将当前批次的损失值记录到 losses 列表中,以便后续计算平均损失
losses.append(loss.item())
epoch_loss = np.mean(losses)
data_loader_tqdm.set_description(
f"This epoch is {str(i + 1)} and it's loss is {loss.item()}, average loss {epoch_loss}"
)
# 反向传播:根据当前损失值计算模型参数的梯度
loss.backward()
# 使用优化器更新模型参数,根据梯度调整模型参数
optimizer.step()
epoch_losses.append(epoch_loss)
# 每过一个batch就保存一次模型
torch.save(model.state_dict(), f'model/resnet18_{str(i + 1)}_{epoch_loss}.pth')
# loss 变化绘制代码
data = np.array(epoch_losses)
plt.figure(figsize=(10, 6))
plt.plot(data)
plt.title(f"{epoch} epoch loss change")
plt.xlabel("epoch")
plt.ylabel("Loss")
# 显示图像
plt.show()
print(f"completed. Model saved.")
if __name__ == '__main__':
train(40)