2024/9/7 更新代码

This commit is contained in:
taisuii
2024-09-07 11:31:00 +08:00
parent 592ef02776
commit d444790657
16 changed files with 700 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
/development/dataset
/development/model
/development/test

8
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

11
.idea/ClassificationCaptchaOcr.iml generated Normal file
View File

@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/development/dataset" />
<excludeFolder url="file://$MODULE_DIR$/development/model" />
</content>
<orderEntry type="jdk" jdkName="torch" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

4
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="torch" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/ClassificationCaptchaOcr.iml" filepath="$PROJECT_DIR$/.idea/ClassificationCaptchaOcr.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

0
development/__init__.py Normal file
View File

74
development/crop_image.py Normal file
View File

@@ -0,0 +1,74 @@
from PIL import Image, ImageFont, ImageDraw, ImageOps
from io import BytesIO
def convert_png_to_jpg(png_bytes: bytes) -> bytes:
    """Convert PNG image bytes to JPEG image bytes.

    JPEG has no alpha channel, so any transparent region is flattened
    onto a white background. Unlike a plain ``convert("RGB")``, this
    also handles grayscale-with-alpha ("LA") and palette PNGs ("P" mode
    with a transparency chunk), which would otherwise lose their
    transparency information.

    Args:
        png_bytes: raw PNG file contents.

    Returns:
        Raw JPEG file contents.
    """
    png_image = Image.open(BytesIO(png_bytes))
    output_bytes = BytesIO()
    # Normalize every alpha-carrying mode to RGBA so one flattening
    # path covers RGBA, LA, and transparent-palette PNGs alike.
    has_alpha = png_image.mode in ("RGBA", "LA") or (
        png_image.mode == "P" and "transparency" in png_image.info
    )
    if has_alpha:
        png_image = png_image.convert("RGBA")
        # Paste onto a white canvas; the image's own alpha channel is
        # used as the paste mask, so transparent pixels become white.
        white_bg = Image.new("RGB", png_image.size, (255, 255, 255))
        white_bg.paste(png_image, (0, 0), png_image)
        jpg_image = white_bg
    else:
        # No transparency: a direct RGB conversion is sufficient.
        jpg_image = png_image.convert("RGB")
    jpg_image.save(output_bytes, format="JPEG")
    return output_bytes.getvalue()
def crop_image(image_bytes, coordinates):
    """Split an image into a 3x3 grid and return the requested cells.

    Args:
        image_bytes: raw image file contents (any PIL-readable format).
        coordinates: iterable of ``[row, col]`` pairs, 1-based, row first.

    Returns:
        List of cropped PIL images, one per coordinate, in input order.
    """
    source = Image.open(BytesIO(image_bytes))
    total_w, total_h = source.size
    # Cell size is the integer third of each dimension; any remainder
    # pixels on the right/bottom edges are simply never cropped.
    cell_w = total_w // 3
    cell_h = total_h // 3
    crops = []
    for row, col in coordinates:
        left = (col - 1) * cell_w
        top = (row - 1) * cell_h
        crops.append(source.crop((left, top, left + cell_w, top + cell_h)))
    return crops
if __name__ == "__main__":
    # Crop order: left to right, top to bottom.
    # NOTE(review): the original comment called these [x, y], but
    # crop_image unpacks each pair as (row, col) — confirm intent.
    coordinates = [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3],
    ]
    # Split the background captcha into its nine grid cells.
    with open("./image_test/bg.jpg", "rb") as rb:
        bg_img = rb.read()
    cropped_images = crop_image(bg_img, coordinates)
    # Save each cell to disk for manual inspection.
    for j, img_crop in enumerate(cropped_images):
        img_crop.save(f"./image_test/bg{j}.jpg")
    # Convert the PNG icon to JPEG (flattens transparency to white).
    with open("./image_test/icon.png", "rb") as rb:
        icon_img = rb.read()
    icon_img_jpg = convert_png_to_jpg(icon_img)
    with open("./image_test/icon.jpg", "wb") as wb:
        wb.write(icon_img_jpg)

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.4 KiB

148
development/predict.py Normal file
View File

@@ -0,0 +1,148 @@
import os
import numpy as np
from development.resnet18 import MyResNet18, data_transform
from development.crop_image import crop_image, convert_png_to_jpg
import torch
import time
from PIL import Image
from io import BytesIO
import onnxruntime as ort
def predict(icon_image, bg_image):
    """PyTorch reference path: rank the nine grid cells by similarity to the icon.

    Args:
        icon_image: JPEG bytes of the target icon.
        bg_image: JPEG bytes of the 3x3 background captcha.

    Side effects:
        Loads the .pth checkpoint from disk on every call and prints
        timings, per-cell scores, and the top-3 (index, score) pairs.
        Returns None — this is a debug/reference path; predict_onnx is
        the one used in production.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
    # Crop order: left to right, top to bottom; entries are (row, col), 1-based.
    coordinates = [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3],
    ]
    # Batch layout: index 0 is the icon, indices 1..9 are the grid cells.
    target_images = []
    target_images.append(data_transform(Image.open(BytesIO(icon_image))))
    bg_images = crop_image(bg_image, coordinates)
    for bg_image in bg_images:
        target_images.append(data_transform(bg_image))
    start = time.time()
    model = MyResNet18(num_classes=91)  # num_classes must match training
    model.load_state_dict(torch.load(model_path))
    model.eval()
    print("加载模型,耗时:", time.time() - start)
    start = time.time()
    # Single forward pass over the whole 10-image batch.
    target_images = torch.stack(target_images, dim=0)
    target_outputs = model(target_images)
    scores = []
    for i, out_put in enumerate(target_outputs):
        if i == 0:
            # Keep the icon embedding; add a batch dim for cosine_similarity.
            target_output = out_put.unsqueeze(0)
        else:
            similarity = torch.nn.functional.cosine_similarity(
                target_output, out_put.unsqueeze(0)
            )
            scores.append(similarity.cpu().item())
    # Per-cell confidence, left to right then top to bottom.
    print(scores)
    # Sort the scores while preserving their original cell indices.
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    # The three cells most similar to the icon.
    largest_three = sorted_arr[:3]
    print(largest_three)
    print("识别完成,耗时:", time.time() - start)
# --- Module-level ONNX session setup ---
# Load the exported ONNX model once at import time so predict_onnx()
# can reuse the same session (and cached input name) across calls.
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("加载模型,耗时:", time.time() - start)
def predict_onnx(icon_image, bg_image):
    """Score the icon against the nine grid cells with the ONNX model.

    Args:
        icon_image: JPEG bytes of the target icon (alpha already flattened).
        bg_image: JPEG bytes of the 3x3 background captcha.

    Returns:
        The ``[row, col]`` coordinates (1-based) of the three cells most
        similar to the icon, best match first.
    """
    # Crop order: left to right, top to bottom; entries are (row, col), 1-based.
    coordinates = [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3],
    ]

    def cosine_similarity(vec1, vec2):
        # Cosine similarity between two 1-D embedding vectors.
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        dot_product = np.dot(vec1, vec2)
        norm_vec1 = np.linalg.norm(vec1)
        norm_vec2 = np.linalg.norm(vec2)
        return dot_product / (norm_vec1 * norm_vec2)

    def data_transforms(image):
        # NumPy mirror of the torchvision training transform: resize to
        # 224x224, scale to [0, 1], normalize with the ImageNet mean/std,
        # then reorder HWC -> CHW.
        image = image.resize((224, 224))
        image_array = np.asarray(image, dtype=np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image_array = (image_array - mean) / std
        return np.transpose(image_array, (2, 0, 1))

    # Batch layout: index 0 is the icon, indices 1..9 are the grid cells.
    target_images = [data_transforms(Image.open(BytesIO(icon_image)))]
    for cell in crop_image(bg_image, coordinates):
        target_images.append(data_transforms(cell))
    start = time.time()
    # BUG FIX: onnxruntime requires the feed value to be a single ndarray,
    # not a Python list of arrays — stack into (10, 3, 224, 224) float32.
    batch = np.stack(target_images, axis=0).astype(np.float32)
    outputs = session.run(None, {input_name: batch})[0]
    # Row 0 is the icon embedding; rows 1..9 are the grid-cell embeddings.
    icon_embedding = outputs[0]
    scores = [cosine_similarity(icon_embedding, emb) for emb in outputs[1:]]
    # Sort the per-cell scores while preserving their cell indices, then
    # keep the three best matches (highest similarity first).
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    largest_three = sorted_arr[:3]
    answer = [coordinates[i[0]] for i in largest_three]
    print(f"识别完成{answer},耗时: {time.time() - start}")
    return answer
if __name__ == "__main__":
    # Smoke test against the sample images: flatten the PNG icon to JPEG,
    # then run the ONNX prediction path end to end.
    with open("image_test/icon.png", "rb") as rb:
        icon_image = convert_png_to_jpg(rb.read())
    with open("image_test/bg.jpg", "rb") as rb:
        bg_image = rb.read()
    predict_onnx(icon_image, bg_image)

17
development/pth2onnx.py Normal file
View File

@@ -0,0 +1,17 @@
from resnet18 import MyResNet18
import torch
def convert():
    """Export the trained PyTorch ResNet-18 checkpoint to ONNX.

    Writes the graph to ``model/resnet18.onnx`` with a dynamic batch
    dimension, so callers may run any batch size (predict_onnx uses 10:
    one icon plus nine grid cells).
    """
    # Load the PyTorch checkpoint; map_location lets the export run on
    # CPU-only machines even if the weights were saved from a CUDA run.
    model_path = "model/resnet18_38_0.021147585306924.pth"
    model = MyResNet18(num_classes=91)  # head size must match training
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    # Example input fixing the non-batch dims: (N, 3, 224, 224).
    dummy_input = torch.randn(10, 3, 224, 224)
    # Export to ONNX, marking axis 0 as dynamic on both ends.
    torch.onnx.export(
        model,
        dummy_input,
        "model/resnet18.onnx",
        verbose=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )


if __name__ == '__main__':
    convert()

117
development/resnet18.py Normal file
View File

@@ -0,0 +1,117 @@
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
# Transform applied to every image before it enters the network: resize
# to ResNet's expected 224x224, convert to a CHW float tensor in [0, 1],
# then normalize with the ImageNet mean/std (matches pretrained weights).
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # resize to the model's input size
        transforms.ToTensor(),  # PIL image -> float tensor, CHW, [0, 1]
        transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        ),  # per-channel ImageNet normalization
    ]
)
# Dataset definition
class CustomDataset:
    """Thin wrapper around torchvision's ImageFolder.

    Each subdirectory of ``data_dir`` is one class; every image inside
    it yields a ``(transformed tensor, class index)`` pair.
    """

    def __init__(self, data_dir):
        self.dataset = ImageFolder(root=data_dir, transform=data_transform)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # ImageFolder already returns an (image, label) tuple.
        return self.dataset[idx]
class MyResNet18(torch.nn.Module):
    """ImageNet-pretrained ResNet-18 with its classifier head replaced.

    The final fully-connected layer maps the backbone's 512-dim feature
    vector to ``num_classes`` logits, which also serve as embeddings for
    cosine-similarity matching at inference time.
    """

    def __init__(self, num_classes):
        super(MyResNet18, self).__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        # ResNet-18's penultimate feature size is 512.
        backbone.fc = nn.Linear(512, num_classes)
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
def train(epoch):
    """Train MyResNet18 on the ImageFolder dataset under ./dataset.

    Args:
        epoch: number of full passes over the dataset.

    Side effects:
        Saves a checkpoint under model/ after every epoch (filename
        encodes epoch number and average loss) and shows a matplotlib
        plot of the per-epoch average loss when training finishes.
    """
    print("judge the cuda: " + str(torch.version.cuda))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("this train use devices: " + str(device))
    data_dir = "dataset"
    # Custom dataset instance (one class per subfolder).
    custom_dataset = CustomDataset(data_dir)
    # Data loader.
    batch_size = 64
    data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)
    # num_classes is the number of subfolders under data_dir — one class
    # per subfolder; the model's output vector has this length.
    model = MyResNet18(num_classes=91)
    model.to(device)
    # Loss function.
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    epoch_losses = []
    # Training loop.
    for i in range(epoch):
        losses = []
        # Progress bar over this epoch's batches.
        data_loader_tqdm = tqdm(data_loader)
        epoch_loss = 0
        for inputs, labels in data_loader_tqdm:
            # Move inputs and labels to the selected device (GPU or CPU).
            inputs, labels = inputs.to(device), labels.to(device)
            # Zero all parameter gradients so they don't accumulate.
            optimizer.zero_grad()
            # Forward pass.
            outputs = model(inputs)
            # Compute the loss against the true labels.
            loss = criterion(outputs, labels)
            # Track per-batch losses to report a running epoch average.
            losses.append(loss.item())
            epoch_loss = np.mean(losses)
            data_loader_tqdm.set_description(
                f"This epoch is {str(i + 1)} and it's loss is {loss.item()}, average loss {epoch_loss}"
            )
            # Backward pass: compute gradients.
            loss.backward()
            # Update model parameters from the gradients.
            optimizer.step()
        epoch_losses.append(epoch_loss)
        # Checkpoint once per epoch (the original comment said "per batch",
        # but this line runs at epoch scope).
        torch.save(model.state_dict(), f'model/resnet18_{str(i + 1)}_{epoch_loss}.pth')
    # Plot the loss curve across epochs.
    data = np.array(epoch_losses)
    plt.figure(figsize=(10, 6))
    plt.plot(data)
    plt.title(f"{epoch} epoch loss change")
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    # Display the figure.
    plt.show()
    print(f"completed. Model saved.")


if __name__ == '__main__':
    train(40)

3
development/test/test.py Normal file
View File

@@ -0,0 +1,3 @@
import torch

# Environment sanity check for training.
print(torch.version.cuda)  # CUDA version the PyTorch build includes (None on CPU-only builds)
print(torch.cuda.is_available())  # whether a usable CUDA device is present

295
main.py Normal file
View File

@@ -0,0 +1,295 @@
import numpy as np
from fake_useragent import UserAgent
from flask import Flask, request
import uuid
import re
import json
from loguru import logger
import requests
import time
import random
from binascii import b2a_hex, a2b_hex
import rsa
import hashlib
from Crypto.Cipher import AES
import execjs
from matplotlib import pyplot as plt
from development.predict import predict_onnx
from development.crop_image import convert_png_to_jpg
class Encrypt():
    """Crypto helpers that build GeeTest v4's encrypted ``w`` parameter.

    The flow mirrors the official JS client: solve a proof-of-work,
    assemble the answer payload, AES-CBC encrypt it under a random key,
    and append that key RSA-encrypted with GeeTest's public key.
    """

    def rsa_encrypt(self, msg):
        """RSA-encrypt ``msg`` with GeeTest's fixed public key; returns hex bytes."""
        # Public exponent and modulus are hard-coded hex values taken
        # from the GeeTest client.
        e = '010001'
        e = int(e, 16)
        n = '00C1E3934D1614465B33053E7F48EE4EC87B14B95EF88947713D25EECBFF7E74C7977D02DC1D9451F79DD5D1C10C29ACB6A9B4D6FB7D0A0279B6719E1772565F09AF627715919221AEF91899CAE08C0D686D748B20A3603BE2318CA6BC2B59706592A9219D0BF05C9F65023A21D2330807252AE0066D59CEEFA5F2748EA80BAB81'
        n = int(n, 16)
        pub_key = rsa.PublicKey(e=e, n=n)
        return b2a_hex(rsa.encrypt(bytes(msg.encode()), pub_key))

    def aes_encrypt(self, key, iv, content):
        """AES-CBC encrypt ``content`` (PKCS7-padded) and return hex ciphertext."""
        def pkcs7padding(text):
            """明文使用PKCS7填充 """
            # Pad the plaintext to a 16-byte block boundary; the pad
            # character encodes the pad length (PKCS7).
            bs = 16
            length = len(text)
            bytes_length = len(text.encode('utf-8'))
            # Use the byte length when multi-byte characters are present.
            padding_size = length if (bytes_length == length) else bytes_length
            padding = bs - padding_size % bs
            padding_text = chr(padding) * padding
            # Side effect: remember the pad char (unused elsewhere here).
            self.coding = chr(padding)
            return text + padding_text
        key = key.encode('utf-8')
        iv = iv.encode('utf-8')
        """ AES加密 """
        cipher = AES.new(key, AES.MODE_CBC, iv)
        # Pad the plaintext.
        content_padding = pkcs7padding(content)
        # Encrypt.
        encrypt_bytes = cipher.encrypt(content_padding.encode('utf-8'))
        # Re-encode the ciphertext as hex text.
        result = b2a_hex(encrypt_bytes).decode()
        return result

    def get_random_key_16(self):
        """Return a 16-char random hex string (four 4-char chunks, as in the JS client)."""
        data = ""
        for i in range(4):
            data += (format((int((1 + random.random()) * 65536) | 0), "x")[1:])
        return data

    def get_pow(self, pow_detail, captcha_id, lot_number):
        """Solve GeeTest's proof-of-work.

        Searches random nonces until hash(prefix + nonce) starts with the
        required number of zero bits. Returns {"pow_msg", "pow_sign"}.
        """
        n = pow_detail['hashfunc']   # hash algorithm name: md5 / sha1 / sha256
        i = pow_detail['version']
        r = pow_detail['bits']       # difficulty in bits
        s = pow_detail['datetime']
        o = ""
        a = r % 4                    # leftover bits beyond whole hex digits
        u = r // 4                   # required count of leading zero hex digits
        c = '0' * u
        # Fixed message prefix; the random nonce h is appended each try.
        _ = f"{i}|{r}|{n}|{s}|{captcha_id}|{lot_number}|{o}|"
        while True:
            h = self.get_random_key_16()
            l = _ + h
            if n == "md5":
                p = hashlib.md5(l.encode()).hexdigest()
            elif n == "sha1":
                p = hashlib.sha1(l.encode()).hexdigest()
            elif n == "sha256":
                p = hashlib.sha256(l.encode()).hexdigest()
            if a == 0:
                # Difficulty is a whole number of hex digits.
                if p.startswith(c):
                    return {"pow_msg": _ + h, "pow_sign": p}
            else:
                if p.startswith(c):
                    # Check the partial hex digit: its top `a` bits must be 0,
                    # i.e. the digit value must not exceed the threshold f.
                    d = int(p[u], 16)
                    if a == 1:
                        f = 7
                    elif a == 2:
                        f = 3
                    elif a == 3:
                        f = 1
                    if d <= f:
                        return {"pow_msg": _ + h, "pow_sign": p}

    def gt_data_assembly(self, pow_detail, captcha_id, lot_number, dynamic_parameter, userresponse):
        """Build the encrypted ``w`` parameter: AES(payload) + RSA(aes_key).

        Args:
            pow_detail: proof-of-work spec from the /load response.
            captcha_id / lot_number: challenge identifiers.
            dynamic_parameter: obfuscated key extracted from gcaptcha4.js.
            userresponse: the solved answer (list of [row, col] cells).
        """
        pow_data = self.get_pow(pow_detail, captcha_id, lot_number)
        e = {
            "passtime": random.randint(1500, 4000),
            "userresponse": userresponse,
            "device_id": "",
            "lot_number": lot_number,
            "pow_msg": pow_data['pow_msg'],
            "pow_sign": pow_data['pow_sign'],
            "geetest": "captcha",
            "lang": "zh",
            "ep": "123",
            "biht": "1426265548",
            "gee_guard": '',
            "em": {"ph": 0, "cp": 0, "ek": "11", "wd": 1, "nt": 0, "si": 0, "sc": 0}
        }
        e.update(dynamic_parameter)
        # Serialize the dict as compact pseudo-JSON the way the JS client does.
        e = str(e).replace('\'', '"').replace(' ', '')
        aes_key = self.get_random_key_16()
        rsa_result = str(self.rsa_encrypt(msg=aes_key), 'utf-8')
        aes_result = self.aes_encrypt(key=aes_key, iv='0000000000000000', content=e)
        w = aes_result + rsa_result
        return w
class GEETEST4():
    """Client for the GeeTest v4 captcha flow: load -> solve (ocr) -> verify."""

    def __init__(self, proxies, captcha_id, risk_type):
        """
        Args:
            proxies: "host:port" proxy string, or "" / "no" for direct connection.
            captcha_id: site captcha id to solve.
            risk_type: requested captcha type (e.g. "nine" for the 3x3 grid).
        """
        self.risk_type = risk_type
        if proxies == "no" or proxies == "":
            self.proxy = None
        else:
            proxies = proxies.replace("\n", "").replace("\r", "")
            self.proxy = {
                "http": f"http://{proxies}",
                "https": f"http://{proxies}"
            }
        self.captcha_id = captcha_id
        # Random User-Agent per instance; session reuses it for all requests.
        ua = UserAgent()
        self.headers = {
            "User-Agent": ua.random,
            "Referer": "https://gt4.geetest.com/"
        }
        self.session = requests.Session()
        self.session.headers = self.headers

    def get_load(self):
        """Fetch the /load challenge data (JSONP) and store it in self.load."""
        url = "https://gcaptcha4.geetest.com/load"
        params = {
            "captcha_id": self.captcha_id,
            "challenge": uuid.uuid4(),
            "client_type": "web",
            "risk_type": self.risk_type,
            "lang": "zh-cn",
            "callback": "geetest_" + str(int(time.time() * 1000))
        }
        response = self.session.get(url, headers=self.headers, params=params, proxies=self.proxy).text
        # Strip the JSONP wrapper: geetest_<ts>({...}).
        response = json.loads(re.findall(r"geetest_\d+\((.*?}})\)", response)[0])
        self.load = response
        # The server may answer with a different captcha type than requested.
        self.risk_type = response["data"]["captcha_type"]

    def get_dynamic_parameter(self):
        """Extract the obfuscated dynamic parameter from gcaptcha4.js.

        A throwaway /load call (fixed demo captcha_id) is used only to learn
        the current static_path of the JS bundle; a trimmed copy of the
        script is then evaluated with execjs to read the parameter table.
        """
        url = "https://gcaptcha4.geetest.com/load"
        params = {
            "captcha_id": "0b2abaab0ad3f4744ab45342a2f3d409",
            "challenge": uuid.uuid4(),
            "client_type": "web",
            "risk_type": "nine",
            "lang": "zh-cn",
            "callback": "geetest_" + str(int(time.time() * 1000))
        }
        response = requests.get(url, headers=self.headers, params=params).text
        response = json.loads(re.findall(r"geetest_\d+\((.*?}})\)", response)[0])
        static_path = response["data"]["static_path"]
        gcaptcha_js = "https://static.geetest.com/" + static_path + "/js/gcaptcha4.js"
        # Keep only the bundle prefix (everything before ";Uaaco"), which is
        # enough to define the parameter table without running the widget.
        js = requests.get(gcaptcha_js, headers={
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.198 Safari/537.36",
            "Referer": "https://gt4.geetest.com/",
        }).text.split(";Uaaco")[0]
        complete_js = """
        Uaaco = {};
        """ + js + """
        function getDynamicParameter() {
            return Uaaco.$_AL.$_HIBAt(781);
        }
        """
        dynamic_parameter = execjs.compile(complete_js).call("getDynamicParameter")
        return json.loads(dynamic_parameter)

    def get_captcha_img(self):
        """Download the challenge images for the loaded challenge.

        Returns (shape depends on captcha type):
            bg bytes only; or (bg bytes, icon bytes); or (bg bytes, [i1, i2, i3]).
        """
        url = "https://static.geetest.com/"
        try:
            bg_img_url = url + self.load["data"]["imgs"]
        except Exception as e:
            # Some captcha types carry the background under "bg" instead.
            bg_img_url = url + self.load["data"]["bg"]
        bg_img = requests.get(bg_img_url, headers=self.headers).content
        img_urls = []
        try:
            img_urls = self.load["data"]["ques"]
        except:
            try:
                img_urls.append(self.load["data"]["slice"])
            except:
                img_urls = []
        if img_urls.__len__() == 0:
            return bg_img
        elif img_urls.__len__() == 1:
            return bg_img, self.session.get(url + img_urls[0], headers=self.headers).content
        else:
            img1 = self.session.get(url + img_urls[0], headers=self.headers).content
            img2 = self.session.get(url + img_urls[1], headers=self.headers).content
            img3 = self.session.get(url + img_urls[2], headers=self.headers).content
            return bg_img, [img1, img2, img3]

    def ocr(self):
        """Solve the challenge; returns the [row, col] answer list from predict_onnx."""
        bg_img, icon_img = self.get_captcha_img()
        # NOTE(review): assumes get_captcha_img returned (bg, single icon
        # bytes). A bg-only or 3-image result would break the unpack or
        # convert_png_to_jpg — confirm for risk types other than "nine".
        answer = predict_onnx(convert_png_to_jpg(icon_img), bg_img)
        return answer

    def verify(self):
        """Run the full flow; returns True on captcha success, else False."""
        self.get_load()
        self.dynamic_parameter = self.get_dynamic_parameter()
        url = "https://gcaptcha4.geetest.com/verify"
        pow_detail = self.load["data"]["pow_detail"]
        lot_number = self.load["data"]["lot_number"]
        payload = self.load["data"]["payload"]
        process_token = self.load["data"]["process_token"]
        ocr_result = self.ocr()
        # Encrypt the answer plus proof-of-work into the `w` parameter.
        w = Encrypt().gt_data_assembly(pow_detail, self.captcha_id, lot_number, self.dynamic_parameter, ocr_result)
        params = {
            "captcha_id": self.captcha_id,
            "client_type": "web",
            "lot_number": lot_number,
            "risk_type": self.risk_type,
            "payload": payload,
            "process_token": process_token,
            "payload_protocol": "1",
            "pt": "1",
            "w": w,
            "callback": "geetest_" + str(int(time.time() * 1000))
        }
        response = self.session.get(url, headers=self.headers, params=params, proxies=self.proxy).text
        response = json.loads(re.findall(r"geetest_\d+\((.*?}})\)", response)[0])
        if response["data"]["result"] == "success":
            logger.success(json.dumps(response, ensure_ascii=False))
            return True
        else:
            logger.error(json.dumps(response, ensure_ascii=False))
            return False
app = Flask(__name__)


@app.route("/geetest4", methods=["GET", "POST"])
def geetest4():
    """HTTP endpoint that runs one GeeTest v4 nine-grid solve.

    Returns:
        JSON body {"success": bool} — True when the captcha verified.
    """
    risk_type = "nine"
    captcha_id = "435d94a5f5b138efd5dc9f9ffc7f5621"
    proxy = ""
    Gt4 = GEETEST4(proxy, captcha_id, risk_type)
    # BUG FIX: a Flask view must not return a bare bool (it is not a valid
    # response type and raises a 500); wrap the result in a JSON dict.
    return {"success": Gt4.verify()}
def test():
    """Run one local solve against a demo captcha_id (no Flask involved).

    Returns:
        True when the captcha verified successfully, else False.
    """
    risk_type = "nine"
    captcha_id = "54088bb07d2df3c46b79f80300b0abbe"
    proxy = ""
    Gt4 = GEETEST4(proxy, captcha_id, risk_type)
    return Gt4.verify()
if __name__ == '__main__':
    # app.run(host="0.0.0.0", port=9797)
    # Benchmark mode: attempt 50 solves, record per-attempt latency and
    # the overall success count, then plot the latencies.
    spendtime = []
    success = 0
    for i in range(50):
        start = time.time()
        if test():
            success = success + 1
        spendtime.append(time.time() - start)
    data = np.array(spendtime)
    plt.figure(figsize=(10, 6))
    plt.plot(data)
    # success * 2 scales the 50-attempt count to a /100 rate for the title.
    plt.title(f"verify spend time, average spend: {np.mean(spendtime)} and Success rate: {str(success * 2)}/100")
    plt.ylabel("time")
    # Display the figure.
    plt.show()