63
README.md
Normal file
@@ -0,0 +1,63 @@
# Nine-grid captcha test code

## **This project is for learning and research only. Do not use it for commercial purposes; you alone bear the consequences if you do.**

## Reference projects

Model and V4 dataset: https://github.com/taisuii/ClassificationCaptchaOcr

API: https://github.com/ravizhan/geetest-v3-click-crack

## How to run

### 1. Install dependencies

```
pip install -r requirements.txt
```

### 2. Prepare your own dataset (V3 and V4 differ)

- For dataset details, see the project referenced above. Note that it ships a V4 dataset and there is no V3 demo, so you will have to improvise; training on the V4 data gives rather underwhelming accuracy on V3, so it may be worth trying other models and seeing whether they generalize better.
- To crop V3 images, use `crop_image_v3` in `crop_image.py`; for V4, use `crop_image`. Write your own cropping script around these helpers (see the sketch after this list).
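A minimal sketch of one way to bulk-crop saved V3 captchas into tiles for labelling. It reuses `crop_image_v3` and `validate_path` from this repo; `img_saved/` is where `main.py` archives captchas, while `dataset_raw/` is a hypothetical staging folder whose tiles you still have to sort into class folders by hand:

```python
import os
import shutil

from crop_image import crop_image_v3, validate_path  # helpers from this repo

raw_dir = "img_saved"      # captchas archived by the API service
out_dir = "dataset_raw"    # hypothetical staging folder for unlabelled tiles
os.makedirs(out_dir, exist_ok=True)

for name in os.listdir(raw_dir):
    if not name.endswith(".jpg"):
        continue
    # crop_image_v3 writes cropped_0..9.jpg (nine tiles + the icon) into img_2_val/
    with open(os.path.join(raw_dir, name), "rb") as f:
        crop_image_v3(f.read())
    target = os.path.join(out_dir, name.split(".")[0])
    os.makedirs(target, exist_ok=True)
    for tile in os.listdir(validate_path):
        if tile.startswith("cropped"):
            shutil.move(os.path.join(validate_path, tile), os.path.join(target, tile))
```

Once labelled, the tiles go into `dataset/<class_name>/` folders so `train.py` can read them (see the layout sketch under step 3).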
### 3. Train the model

Run `python train.py`. The expected `dataset/` layout is shown below.
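`train.py` loads the training data with torchvision's `ImageFolder`, so `dataset/` is expected to hold one subfolder per icon class, and the number of subfolders must match `num_classes=91` in `train.py` and `convert.py`. The folder and file names below are only illustrative:

```
dataset/
├── class_001/
│   ├── 0001.jpg
│   └── 0002.jpg
├── class_002/
│   └── ...
└── ...            (91 class folders in total)
```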
### 4. Convert the model to ONNX

Run `python convert.py` (edit the script to point at the checkpoint you want to convert, usually the one with the lowest loss). A quick way to sanity-check the exported model is sketched below.
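An optional sanity check of the exported file, sketched after the way `predict.py` loads it; the batch size of 10 matches the dummy input hard-coded in `convert.py`:

```python
import numpy as np
import onnxruntime as ort

# Load the exported model and look up its input name
session = ort.InferenceSession("model/resnet18.onnx")
input_name = session.get_inputs()[0].name

# Feed a random batch shaped like the dummy input used during export
dummy = np.random.rand(10, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy})[0]
print(outputs.shape)  # expected: (10, 91), one 91-dimensional output per image
```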
### 5. Start the FastAPI service

Run `python main.py`.

### 6. Call the API

Example call from Python:

```python
import httpx


def pass_nine(gt: str, challenge: str):
    res = httpx.get("http://127.0.0.1:9645/pass_nine",
                    params={'gt': gt, 'challenge': challenge}, timeout=10)
    datas = res.json()['data']
    if datas['result'] == 'success':
        return datas['validate']
```
0
__init__.py
Normal file
17
convert.py
Normal file
@@ -0,0 +1,17 @@
from train import MyResNet18
import torch


def convert():
    # Load the PyTorch checkpoint
    model_path = "model/resnet18_39_0.01445627337038193.pth"
    model = MyResNet18(num_classes=91)
    model.load_state_dict(torch.load(model_path))
    model.eval()
    # Build a dummy example input
    dummy_input = torch.randn(10, 3, 224, 224)
    # Export the model to ONNX format
    torch.onnx.export(model, dummy_input, "model/resnet18.onnx", verbose=True)


if __name__ == '__main__':
    convert()
486
crack.py
Normal file
@@ -0,0 +1,486 @@
import hashlib
import json
import math
import random
import time

import httpx
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding, serialization
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from crop_image import validate_path
from os import path as PATH


class Crack:
    def __init__(self, gt=None, challenge=None):
        self.pic_path = None
        self.s = None
        self.c = None
        self.session = httpx.Client(http2=True)
        self.session.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36"
        }
        # self.session.verify = False
        self.gt = gt
        self.challenge = challenge
        self.aeskey = ''.join(f'{int((1 + random.random()) * 65536):04x}'[1:] for _ in range(4))
        public_key = '''-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDB45NNFhRGWzMFPn9I7k7IexS5
XviJR3E9Je7L/350x5d9AtwdlFH3ndXRwQwprLaptNb7fQoCebZxnhdyVl8Jr2J3
FZGSIa75GJnK4IwNaG10iyCjYDviMYymvCtZcGWSqSGdC/Bcn2UCOiHSMwgHJSrg
Bm1Zzu+l8nSOqAurgQIDAQAB
-----END PUBLIC KEY-----'''
        self.public_key = serialization.load_pem_public_key(public_key.encode())
        self.enc_key = self.public_key.encrypt(self.aeskey.encode(), PKCS1v15()).hex()
        with open("mousepath.json", "r") as f:
            self.mouse_path = json.loads(f.read())

    def get_type(self) -> dict:
        url = f"https://api.geetest.com/gettype.php?gt={self.gt}"
        res = self.session.get(url)
        data = json.loads(res.text[1:-1])["data"]
        return data

    @staticmethod
    def encode(input_bytes: list):
        def get_char_from_index(index):
            char_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789()"
            return char_table[index] if 0 <= index < len(char_table) else "."

        def transform_value(value, bit_mask):
            result = 0
            for r in range(23, -1, -1):
                if (bit_mask >> r) & 1:
                    result = (result << 1) + ((value >> r) & 1)
            return result

        encoded_string = ""
        padding = ""
        input_length = len(input_bytes)
        for i in range(0, input_length, 3):
            chunk_length = min(3, input_length - i)
            chunk = input_bytes[i:i + chunk_length]
            if chunk_length == 3:
                value = (chunk[0] << 16) + (chunk[1] << 8) + chunk[2]
                encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
                    transform_value(value, 9483264)) + get_char_from_index(
                    transform_value(value, 19220)) + get_char_from_index(transform_value(value, 235))
            elif chunk_length == 2:
                value = (chunk[0] << 16) + (chunk[1] << 8)
                encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
                    transform_value(value, 9483264)) + get_char_from_index(transform_value(value, 19220))
                padding = "."
            elif chunk_length == 1:
                value = chunk[0] << 16
                encoded_string += get_char_from_index(transform_value(value, 7274496)) + get_char_from_index(
                    transform_value(value, 9483264))
                padding = ".."
        return encoded_string + padding

    @staticmethod
    def MD5(text: str):
        return hashlib.md5(text.encode()).hexdigest()

    @staticmethod
    def encode_mouse_path(path: list, c: list, s: str):
        def preprocess(path: list):
            def BFIQ(e):
                t = 32767
                if not isinstance(e, int):
                    return e
                else:
                    if t < e:
                        e = t
                    elif e < -t:
                        e = -t
                    return round(e)

            def BGAB(e):
                t = ''
                n = 0
                len(e or [])
                while n < len(e) and not t:
                    if e[n]:
                        t = e[n][4]
                    n += 1
                if not t:
                    return e
                r = ''
                i = ['mouse', 'touch', 'pointer', 'MSPointer']
                for s in range(len(i)):
                    if t.startswith(i[s]):
                        r = i[s]
                _ = list(e)
                for a in range(len(_) - 1, -1, -1):
                    c = _[a]
                    l = c[0]
                    if l in ['move', 'down', 'up']:
                        value = c[4] or ''
                        if not value.startswith(r):
                            _.pop(a)
                return _

            t = 0
            n = 0
            r = []
            s = 0
            if len(path) <= 0:
                return []
            o = None
            _ = None
            a = BGAB(path)
            c = len(a)
            for l in range(0 if c < 300 else c - 300, c):
                u = a[l]
                h = u[0]
                if h in ['down', 'move', 'up', 'scroll']:
                    if not o:
                        o = u
                    _ = u
                    r.append([h, [u[1] - t, u[2] - n], BFIQ(u[3] - s if s else s)])
                    t = u[1]
                    n = u[2]
                    s = u[3]
                elif h in ['blur', 'focus', 'unload']:
                    r.append([h, BFIQ(u[1] - s if s else s)])
                    s = u[1]
            return r

        def process(prepared_path: list):
            h = {
                'move': 0,
                'down': 1,
                'up': 2,
                'scroll': 3,
                'focus': 4,
                'blur': 5,
                'unload': 6,
                'unknown': 7
            }

            def p(e, t):
                n = bin(e)[2:]
                r = ''
                i = len(n) + 1
                while i <= t:
                    i += 1
                    r += '0'
                return r + n

            def d(e):
                t = []
                n = len(e)
                r = 0
                while r < n:
                    i = e[r]
                    s = 0
                    while True:
                        if s >= 16:
                            break
                        o = r + s + 1
                        if o >= n:
                            break
                        if e[o] != i:
                            break
                        s += 1
                    r += 1 + s
                    _ = h[i]
                    if s != 0:
                        t.append(_ | 8)
                        t.append(s - 1)
                    else:
                        t.append(_)
                a = p(n | 32768, 16)
                c = ''
                for l in range(len(t)):
                    c += p(t[l], 4)
                return a + c

            def g(e, tt):
                def temp1(e1):
                    n = len(e)
                    r = 0
                    i = []
                    while r < n:
                        s = 1
                        o = e[r]
                        _ = abs(o)
                        while True:
                            if n <= r + s:
                                break
                            if e[r + s] != o:
                                break
                            if (_ >= 127) or (s >= 127):
                                break
                            s += 1
                        if s > 1:
                            i.append((49152 if o < 0 else 32768) | s << 7 | _)
                        else:
                            i.append(o)
                        r += s
                    return i

                e = temp1(e)

                r = []
                i = []

                def n(e, t):
                    return 0 if e == 0 else math.log(e) / math.log(t)

                for temp in e:
                    t = math.ceil(n(abs(temp) + 1, 16))
                    if t == 0:
                        t = 1
                    r.append(p(t - 1, 2))
                    i.append(p(abs(temp), t * 4))

                s = ''.join(r)
                o = ''.join(i)

                def temp2(t):
                    return t != 0 and t >> 15 != 1

                def temp3(e1):
                    n = []

                    def temp(e2):
                        if temp2(e2):
                            n.append(e2)

                    for r in range(len(e1)):
                        temp(e1[r])
                    return n

                def temp4(t):
                    if t < 0:
                        return '1'
                    else:
                        return '0'

                if tt:
                    n = []
                    e1 = temp3(e)
                    for r in range(len(e1)):
                        n.append(temp4(e1[r]))
                    n = ''.join(n)
                else:
                    n = ''
                return p(len(e) | 32768, 16) + s + o + n

            def u(e):
                t = ''
                n = len(e) // 6
                for r in range(n):
                    t += '()*,-./0123456789:?@ABCDEFGHIJKLMNOPQRSTUVWXYZ_abcdefghijklmnopqrstuvwxyz~'[
                        int(e[6 * r: 6 * (r + 1)], 2)]
                return t

            t = []
            n = []
            r = []
            i = []
            for a in range(len(prepared_path)):
                _ = prepared_path[a]
                a = len(_)
                t.append(_[0])
                n.append(_[1] if a == 2 else _[2])
                if a == 3:
                    r.append(_[1][0])
                    i.append(_[1][1])
            c = d(t) + g(n, False) + g(r, True) + g(i, True)
            l = len(c)
            if l % 6 != 0:
                c += p(0, 6 - l % 6)
            return u(c)

        def postprocess(e, t, n):
            i = 0
            s = e
            o = t[0]
            _ = t[2]
            a = t[4]
            while True:
                r = n[i:i + 2]
                if not r:
                    break
                i += 2
                c = int(r, 16)
                l = chr(c)
                u = (o * c * c + _ * c + a) % len(e)
                s = s[:u] + l + s[u:]
            return s

        return postprocess(process(preprocess(path)), c, s)

    def aes_encrypt(self, content: str):
        cipher = Cipher(algorithms.AES(self.aeskey.encode()), modes.CBC(b"0000000000000000"))
        encryptor = cipher.encryptor()
        padder = padding.PKCS7(128).padder()
        padded_data = padder.update(content.encode())
        padded_data += padder.finalize()
        ct = encryptor.update(padded_data) + encryptor.finalize()
        return ct

    def get_c_s(self):
        o = {
            "gt": self.gt,
            "challenge": self.challenge,
            "offline": False,
            "new_captcha": True,
            "product": "embed",
            "width": "300px",
            "https": True,
            "protocol": "https://",
        }
        o.update(self.get_type())
        o.update({
            "cc": 16,
            "ww": True,
"i": "-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1!!-1"
        })
        o = json.dumps(o, separators=(',', ':'))
        ct = self.aes_encrypt(o)
        s = []
        for byte in ct:
            s.append(byte)
        i = self.encode(s)
        r = self.enc_key
        w = i + r
        params = {
            "gt": self.gt,
            "challenge": self.challenge,
            "lang": "zh-cn",
            "pt": 0,
            "client_type": "web",
            "callback": "geetest_" + str(int(round(time.time() * 1000))),
            "w": w
        }
        resp = self.session.get("https://api.geetest.com/get.php", params=params).text
        data = json.loads(resp[22:-1])["data"]
        self.c = data["c"]
        self.s = data["s"]
        return data["c"], data["s"]

    def gettype(self):
        url = f"https://api.geetest.com/gettype.php?gt={self.gt}&callback=geetest_{str(int(round(time.time() * 1000)))}"
        return self.session.get(url).text

    def ajax(self):
        def transform(e, t, n):
            if not t or not n:
                return e
            o = 0
            i = list(e)
            s = t[0]
            a = t[2]
            b = t[4]
            while o < len(n):
                r = n[o:o + 2]
                o += 2
                c = int(r, 16)
                l = chr(c)
                u = (s * c * c + a * c + b) % len(i)
                i.insert(u, l)
            return ''.join(i)

        mouse_path = [
            ["move", 385, 313, 1724572150164, "pointermove"],
            ["move", 385, 315, 1724572150166, "pointermove"],
            ["move", 386, 315, 1724572150174, "pointermove"],
            ["move", 387, 315, 1724572150182, "pointermove"],
            ["move", 387, 316, 1724572150188, "pointermove"],
            ["move", 388, 316, 1724572150204, "pointermove"],
            ["move", 388, 317, 1724572150218, "pointermove"],
            ["down", 388, 317, 1724572150586, "pointerdown"],
            ["focus", 1724572150587],
            ["up", 388, 317, 1724572150632, "pointerup"]
        ]
        tt = transform(self.encode_mouse_path(mouse_path, self.c, self.s), self.c, self.s)
        rp = self.MD5(self.gt + self.challenge + self.s)
        temp1 = '''"lang":"zh-cn","type":"fullpage","tt":"%s","light":"DIV_0","s":"c7c3e21112fe4f741921cb3e4ff9f7cb","h":"321f9af1e098233dbd03f250fd2b5e21","hh":"39bd9cad9e425c3a8f51610fd506e3b3","hi":"09eb21b3ae9542a9bc1e8b63b3d9a467","vip_order":-1,"ct":-1,"ep":{"v":"9.1.9-dbjg5z","te":false,"me":true,"ven":"Google Inc. (Intel)","ren":"ANGLE (Intel, Intel(R) Iris(R) Xe Graphics (0x0000A7A0) Direct3D11 vs_5_0 ps_5_0, D3D11)","fp":["scroll",0,1602,1724571628498,null],"lp":["up",386,217,1724571629854,"pointerup"],"em":{"ph":0,"cp":0,"ek":"11","wd":1,"nt":0,"si":0,"sc":0},"tm":{"a":1724571567311,"b":1724571567549,"c":1724571567562,"d":0,"e":0,"f":1724571567312,"g":1724571567312,"h":1724571567312,"i":1724571567317,"j":1724571567423,"k":1724571567330,"l":1724571567423,"m":1724571567545,"n":1724571567547,"o":1724571567569,"p":1724571568259,"q":1724571568259,"r":1724571568261,"s":1724571570378,"t":1724571570378,"u":1724571570380},"dnf":"dnf","by":0},"passtime":1600,"rp":"%s",''' % (
            tt, rp)
        r = "{" + temp1 + '"captcha_token":"1198034057","du6o":"eyjf7nne"}'
        ct = self.aes_encrypt(r)
        s = [byte for byte in ct]
        w = self.encode(s)
        params = {
            "gt": self.gt,
            "challenge": self.challenge,
            "lang": "zh-cn",
            "pt": 0,
            "client_type": "web",
            "callback": "geetest_" + str(int(round(time.time() * 1000))),
            "w": w
        }
        resp = self.session.get("https://api.geetest.com/ajax.php", params=params).text
        return json.loads(resp[22:-1])["data"]

    def get_pic(self):
        params = {
            "is_next": "true",
            "type": "click",
            "gt": self.gt,
            "challenge": self.challenge,
            "lang": "zh-cn",
            "https": "true",
            "protocol": "https://",
            "offline": "false",
            "product": "float",
            "api_server": "api.geevisit.com",
            "isPC": True,
            "autoReset": True,
            "width": "100%",
            "callback": "geetest_" + str(int(round(time.time() * 1000))),
        }
        resp = self.session.get("https://api.geevisit.com/get.php", params=params).text
        data = json.loads(resp[22:-1])["data"]
        self.pic_path = data["pic"]
        pic_url = "https://" + data["resource_servers"][0][:-1] + data["pic"]
        pic_data = self.session.get(pic_url).content
        pic_name = data["pic"].split("/")[-1]
        with open(PATH.join(validate_path, pic_name), 'wb+') as f:
            f.write(pic_data)
        return pic_data, pic_name

    def verify(self, points: list):
        u = self.enc_key
        o = {
            "lang": "zh-cn",
            "passtime": 1600,
            "a": ",".join(points),
            "pic": self.pic_path,
            "tt": self.encode_mouse_path(self.mouse_path, self.c, self.s),
            "ep": {
                "ca": [{"x": 524, "y": 209, "t": 0, "dt": 1819}, {"x": 558, "y": 299, "t": 0, "dt": 428},
                       {"x": 563, "y": 95, "t": 0, "dt": 952}, {"x": 670, "y": 407, "t": 3, "dt": 892}],
                "v": '3.1.0',
                "$_FB": False,
                "me": True,
                "tm": {"a": 1724585496403, "b": 1724585496605, "c": 1724585496613, "d": 0, "e": 0, "f": 1724585496404,
                       "g": 1724585496404, "h": 1724585496404, "i": 1724585496404, "j": 1724585496404, "k": 0,
                       "l": 1724585496413, "m": 1724585496601, "n": 1724585496603, "o": 1724585496618,
                       "p": 1724585496749, "q": 1724585496749, "r": 1724585496751, "s": 1724585498068,
                       "t": 1724585498068, "u": 1724585498069}
            },
            "h9s9": "1816378497",
        }
        o["rp"] = self.MD5(self.gt + self.challenge + str(o["passtime"]))
        o = json.dumps(o, separators=(',', ':'))
        ct = self.aes_encrypt(o)
        s = []
        for byte in ct:
            s.append(byte)
        p = self.encode(s)
        w = p + u
        params = {
            "gt": self.gt,
            "challenge": self.challenge,
            "lang": "zh-cn",
            "pt": 0,
            "client_type": "web",
            "w": w
        }
        resp = self.session.get("https://api.geevisit.com/ajax.php", params=params).text
        return resp[1:-1]
139
crop_image.py
Normal file
@@ -0,0 +1,139 @@
from PIL import Image, ImageFont, ImageDraw, ImageOps
from io import BytesIO
import cv2
import numpy as np
import os

current_path = os.getcwd()
validate_path = os.path.join(current_path, 'img_2_val')  # temporary storage for the captcha currently being solved
save_path = os.path.join(current_path, 'img_saved')  # archive of past captchas, kept for building a dataset to label
save_pass_path = os.path.join(save_path, 'img_pass')  # images that passed verification (individual tiles may still be misclassified)
save_fail_path = os.path.join(save_path, 'img_fail')  # images that failed verification (possibly a bad mouse track rather than a classification error)
os.makedirs(validate_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)
os.makedirs(save_pass_path, exist_ok=True)
os.makedirs(save_fail_path, exist_ok=True)


def draw_points_on_image(bg_image, answer):
    # Decode the background image bytes into an OpenCV image
    bg_image_cv = cv2.imdecode(np.frombuffer(bg_image, dtype=np.uint8), cv2.IMREAD_COLOR)

    # Cell size of the 3x3 grid and the offset to the centre of the first cell
    grid_width = 100
    grid_height = 86
    offset_x = 45
    offset_y = 38

    for i, (row, col) in enumerate(answer):
        x = offset_x + (col - 1) * grid_width
        y = offset_y + (row - 1) * grid_height
        cv2.circle(bg_image_cv, (x, y), 10, (0, 0, 255), -1)
    cv2.imwrite('./img_2_val/predicted.jpg', bg_image_cv)  # visualised prediction


def convert_png_to_jpg(png_bytes: bytes) -> bytes:
    # Load the incoming bytes as a PIL image
    png_image = Image.open(BytesIO(png_bytes))

    # Buffer for the JPG output
    output_bytes = BytesIO()

    # Check whether the image has an alpha channel (RGBA)
    if png_image.mode == 'RGBA':
        # Create a white background
        white_bg = Image.new("RGB", png_image.size, (255, 255, 255))
        # Paste the PNG onto the white background so transparent areas become white
        white_bg.paste(png_image, (0, 0), png_image)
        jpg_image = white_bg
    else:
        # No alpha channel: convert straight to RGB
        jpg_image = png_image.convert("RGB")

    # Save the converted image as JPEG into the buffer
    jpg_image.save(output_bytes, format="JPEG")

    # Return the JPG bytes
    return output_bytes.getvalue()


def crop_image(image_bytes, coordinates):
    img = Image.open(BytesIO(image_bytes))
    width, height = img.size
    grid_width = width // 3
    grid_height = height // 3
    cropped_images = []
    for coord in coordinates:
        y, x = coord
        left = (x - 1) * grid_width
        upper = (y - 1) * grid_height
        right = left + grid_width
        lower = upper + grid_height
        box = (left, upper, right, lower)
        cropped_img = img.crop(box)
        cropped_images.append(cropped_img)
    return cropped_images


def crop_image_v3(image_bytes):
    coordinates = [
        [[0, 0], [112, 112]],
        [[116, 0], [228, 112]],
        [[232, 0], [344, 112]],  # first row
        [[0, 116], [112, 228]],
        [[116, 116], [228, 228]],
        [[232, 116], [344, 228]],  # second row
        [[0, 232], [112, 344]],
        [[116, 232], [228, 344]],
        [[232, 232], [344, 344]],  # third row
        [[2, 344], [42, 384]]  # the icon to match
    ]
    image = Image.open(BytesIO(image_bytes))
    image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
    imageNew = Image.new('RGB', (300, 261), (0, 0, 0))
    images = []
    for i, (start_point, end_point) in enumerate(coordinates):
        x1, y1 = start_point
        x2, y2 = end_point
        # Crop one cell
        cropped_image = image.crop((x1, y1, x2, y2))
        images.append(cropped_image)
        # Save the cropped cell
        output_path = os.path.join(validate_path, f'cropped_{i}.jpg')
        cropped_image.save(output_path)
    for i in range(3):
        imageNew.paste(images[i].resize((100, 86)), (i * 100, 0, (i + 1) * 100, 86))
        imageNew.paste(images[i + 3].resize((100, 86)), (i * 100, 86, (i + 1) * 100, 172))
        imageNew.paste(images[i + 6].resize((100, 86)), (i * 100, 172, (i + 1) * 100, 258))
    imageNew.save(os.path.join(validate_path, 'nine.jpg'))


if __name__ == "__main__":
    # V4 test code
    # os.makedirs(os.path.join(current_path, 'image_test'), exist_ok=True)
    # # Crop order: left to right, top to bottom [x, y]
    # coordinates = [
    #     [1, 1],
    #     [1, 2],
    #     [1, 3],
    #     [2, 1],
    #     [2, 2],
    #     [2, 3],
    #     [3, 1],
    #     [3, 2],
    #     [3, 3],
    # ]
    # with open("./image_test/bg.jpg", "rb") as rb:
    #     bg_img = rb.read()
    # cropped_images = crop_image(bg_img, coordinates)
    # # Save each crop individually
    # for j, img_crop in enumerate(cropped_images):
    #     img_crop.save(f"./image_test/bg{j}.jpg")

    # # Icon format conversion
    # with open("./image_test/icon.png", "rb") as rb:
    #     icon_img = rb.read()
    # icon_img_jpg = convert_png_to_jpg(icon_img)
    # with open("./image_test/icon.jpg", "wb") as wb:
    #     wb.write(icon_img_jpg)

    # V3 test code
    pic = "./img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg"
    print("Image used for inference:", pic)
    with open(pic, "rb") as f:
        img = f.read()
    crop_image_v3(img)
BIN
img_2_val/nine.jpg
Normal file
After Width: | Height: | Size: 17 KiB |
BIN
img_2_val/predicted.jpg
Normal file
After Width: | Height: | Size: 31 KiB |
BIN
img_saved/00f58d85239040eea15a403d9426669b.jpg
Normal file
After Width: | Height: | Size: 58 KiB |
BIN
img_saved/0693707f08b44825b155ee8d72897d70.jpg
Normal file
After Width: | Height: | Size: 48 KiB |
BIN
img_saved/31158834904c4b808089058d39d239de.jpg
Normal file
After Width: | Height: | Size: 54 KiB |
BIN
img_saved/6d4753c615e54299863d837a321dea9c.jpg
Normal file
After Width: | Height: | Size: 59 KiB |
BIN
img_saved/73fa08e2f40242a9b1b82c2a478e0ee8.jpg
Normal file
After Width: | Height: | Size: 55 KiB |
BIN
img_saved/933762ddd6054c13a4d90740180afe51.jpg
Normal file
After Width: | Height: | Size: 45 KiB |
BIN
img_saved/9aa1728546524a7a8856cbb5db789880.jpg
Normal file
After Width: | Height: | Size: 61 KiB |
BIN
img_saved/d62f7aa9402547c4b17eecde127f10de.jpg
Normal file
After Width: | Height: | Size: 48 KiB |
BIN
img_saved/f105965489de434e930fa1ef8a5bcd9f.jpg
Normal file
After Width: | Height: | Size: 56 KiB |
66
main.py
Normal file
@@ -0,0 +1,66 @@
import json
import time
import random
from crack import Crack
from crop_image import crop_image_v3, save_path, save_fail_path, save_pass_path, validate_path
import httpx
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
import shutil
import os


port = 9645
# API service
app = FastAPI()


@app.get("/pass_nine")
def get_pic(gt: str = Query(...), challenge: str = Query(...), point: str = Query(default=None)):
    print(f"Start solving:\ngt:{gt}\nchallenge:{challenge}")
    t = time.time()

    crack = Crack(gt, challenge)

    crack.gettype()

    crack.get_c_s()

    time.sleep(random.uniform(0.4, 0.6))

    crack.ajax()

    pic_content, pic_name = crack.get_pic()

    crop_image_v3(pic_content)

    with open(f"{validate_path}/cropped_9.jpg", "rb") as rb:
        icon_image = rb.read()
    with open(f"{validate_path}/nine.jpg", "rb") as rb:
        bg_image = rb.read()

    result_list = predict_onnx(icon_image, bg_image, point)
    point_list = [f"{col}_{row}" for row, col in result_list]
    # keep the total solve time around 4 s, and never sleep for a negative amount
    wait_time = max(0.0, 4.0 - (time.time() - t))
    time.sleep(wait_time)
    result = json.loads(crack.verify(point_list))
    shutil.move(os.path.join(validate_path, pic_name), os.path.join(save_path, pic_name))
    if 'validate' in result['data']:
        path_2_save = os.path.join(save_pass_path, pic_name.split('.')[0])
    else:
        path_2_save = os.path.join(save_fail_path, pic_name.split('.')[0])
    os.makedirs(path_2_save, exist_ok=True)
    for pic in os.listdir(validate_path):
        if pic.startswith('cropped'):
            shutil.move(os.path.join(validate_path, pic), os.path.join(path_2_save, pic))
    total_time = time.time() - t
    print(f"Total time (including {wait_time}s wait): {total_time}\n{result}")
    return JSONResponse(content=result)


if __name__ == "__main__":
    from predict import predict_onnx
    import uvicorn
    print(f"{' ' * 10}api: http://127.0.0.1:{port}/pass_nine{' ' * 10}")
    print(f"{' ' * 10}required query params: gt, challenge, point (optional){' ' * 10}")
    uvicorn.run(app, port=port)
2337
mousepath.json
Normal file
159
predict.py
Normal file
@@ -0,0 +1,159 @@
import os

import numpy as np

from train import MyResNet18, data_transform
from crop_image import crop_image, convert_png_to_jpg, draw_points_on_image
import torch
import time
import cv2
from PIL import Image
from io import BytesIO
import onnxruntime as ort


def predict(icon_image, bg_image):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(current_dir, 'model', 'resnet18_38_0.021147585306924.pth')
    coordinates = [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3],
    ]
    target_images = []
    target_images.append(data_transform(Image.open(BytesIO(icon_image))))

    bg_images = crop_image(bg_image, coordinates)
    for bg_image in bg_images:
        target_images.append(data_transform(bg_image))

    start = time.time()
    model = MyResNet18(num_classes=91)  # must match the number of classes used for training
    model.load_state_dict(torch.load(model_path))
    model.eval()
    print("Model loaded, took:", time.time() - start)
    start = time.time()

    target_images = torch.stack(target_images, dim=0)
    target_outputs = model(target_images)

    scores = []

    for i, out_put in enumerate(target_outputs):
        if i == 0:
            # Add a batch dimension so cosine similarity can be computed
            target_output = out_put.unsqueeze(0)
        else:
            similarity = torch.nn.functional.cosine_similarity(
                target_output, out_put.unsqueeze(0)
            )
            scores.append(similarity.cpu().item())
    # Similarity score for each tile, left to right, top to bottom
    print(scores)
    # Sort the scores while keeping their original indices
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    # Take the three largest scores and their indices
    largest_three = sorted_arr[:3]
    print(largest_three)
    print("Inference done, took:", time.time() - start)


# Load the ONNX model once at import time
start = time.time()
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'model', 'resnet18.onnx')
session = ort.InferenceSession(model_path)
input_name = session.get_inputs()[0].name
print("Model loaded, took:", time.time() - start)


def predict_onnx(icon_image, bg_image, point=None):
    coordinates = [
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 1],
        [2, 2],
        [2, 3],
        [3, 1],
        [3, 2],
        [3, 3],
    ]

    def cosine_similarity(vec1, vec2):
        # Convert the inputs to NumPy arrays
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        # Dot product
        dot_product = np.dot(vec1, vec2)
        # Vector norms
        norm_vec1 = np.linalg.norm(vec1)
        norm_vec2 = np.linalg.norm(vec2)
        # Cosine similarity
        similarity = dot_product / (norm_vec1 * norm_vec2)
        return similarity

    def data_transforms(image):
        image = image.resize((224, 224))
        image = Image.fromarray(cv2.cvtColor(np.array(image), cv2.COLOR_RGBA2RGB))
        image_array = np.array(image)
        image_array = image_array.astype(np.float32) / 255.0
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        image_array = (image_array - mean) / std
        image_array = np.transpose(image_array, (2, 0, 1))
        # image_array = np.expand_dims(image_array, axis=0)
        return image_array

    target_images = []
    target_images.append(data_transforms(Image.open(BytesIO(icon_image))))
    bg_images = crop_image(bg_image, coordinates)

    for one in bg_images:
        target_images.append(data_transforms(one))

    start = time.time()
    outputs = session.run(None, {input_name: target_images})[0]

    scores = []
    for i, out_put in enumerate(outputs):
        if i == 0:
            target_output = out_put
        else:
            similarity = cosine_similarity(target_output, out_put)
            scores.append(similarity)
    # Similarity score for each tile, left to right, top to bottom
    # print(scores)
    # Sort the scores while keeping their original indices
    indexed_arr = list(enumerate(scores))
    sorted_arr = sorted(indexed_arr, key=lambda x: x[1], reverse=True)
    # Take the three largest scores and their indices
    if point is None:
        largest_three = sorted_arr[:3]
        answer = [coordinates[i[0]] for i in largest_three]
    # Otherwise keep every tile whose score exceeds the given threshold
    else:
        threshold = float(point)
        answer = [coordinates[one[0]] for one in sorted_arr if one[1] > threshold]
    print(f"Inference done {answer}, took: {time.time() - start}")
    draw_points_on_image(bg_image, answer)
    return answer


if __name__ == "__main__":
    icon_path = "img_2_val/cropped_9.jpg"
    bg_path = "img_2_val/nine.jpg"
    with open(icon_path, "rb") as rb:
        if icon_path.endswith('.png'):
            icon_image = convert_png_to_jpg(rb.read())
        else:
            icon_image = rb.read()
    with open(bg_path, "rb") as rb:
        bg_image = rb.read()
    predict_onnx(icon_image, bg_image)
11
requirements.txt
Normal file
@@ -0,0 +1,11 @@
httpx
cryptography
onnxruntime
opencv-python
numpy
torch
torchvision
Pillow
matplotlib
tqdm
fastapi
uvicorn
119
train.py
Normal file
@@ -0,0 +1,119 @@
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os

os.makedirs(os.path.join(os.getcwd(), 'model'), exist_ok=True)

# Data transforms
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # resize the image
        transforms.ToTensor(),  # convert the image to a tensor
        transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        ),  # normalise the image
    ]
)


# Dataset definition
class CustomDataset:
    def __init__(self, data_dir):
        self.dataset = ImageFolder(root=data_dir, transform=data_transform)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return image, label


class MyResNet18(torch.nn.Module):
    def __init__(self, num_classes):
        super(MyResNet18, self).__init__()
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(512, num_classes)  # resnet18's final feature size is 512

    def forward(self, x):
        return self.resnet(x)


def train(epoch):
    print("CUDA version: " + str(torch.version.cuda))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training on device: " + str(device))

    data_dir = "dataset"
    # Custom dataset instance
    custom_dataset = CustomDataset(data_dir)
    # Data loader
    batch_size = 64
    data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

    # Initialise the model. num_classes is the number of subfolders under the dataset
    # directory: each subfolder is one class, and the model's output vector has this length.
    model = MyResNet18(num_classes=91)
    model.to(device)

    # Loss function
    criterion = torch.nn.CrossEntropyLoss()
    # Optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    epoch_losses = []
    # Training loop
    for i in range(epoch):
        losses = []

        # Progress bar over the data loader
        data_loader_tqdm = tqdm(data_loader)

        epoch_loss = 0
        for inputs, labels in data_loader_tqdm:
            # Move the inputs and labels to the selected device (GPU or CPU)
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients so they do not accumulate
            optimizer.zero_grad()

            # Forward pass: run the inputs through the model
            outputs = model(inputs)

            # Compute the loss from the outputs and the labels
            loss = criterion(outputs, labels)

            # Record the batch loss so the running average can be reported
            losses.append(loss.item())
            epoch_loss = np.mean(losses)
            data_loader_tqdm.set_description(
                f"This epoch is {str(i + 1)} and its loss is {loss.item()}, average loss {epoch_loss}"
            )

            # Backward pass: compute the gradients
            loss.backward()
            # Update the model parameters with the optimizer
            optimizer.step()
        epoch_losses.append(epoch_loss)
        # Save a checkpoint after every epoch
        torch.save(model.state_dict(), f'model/resnet18_{str(i + 1)}_{epoch_loss}.pth')

    # Plot how the loss changed over the epochs
    data = np.array(epoch_losses)
    plt.figure(figsize=(10, 6))
    plt.plot(data)
    plt.title(f"{epoch} epoch loss change")
    plt.xlabel("epoch")
    plt.ylabel("Loss")
    # Show the figure
    plt.show()
    print(f"completed. Model saved.")


if __name__ == '__main__':
    train(40)