mirror of https://github.com/zc-zhangchen/any-auto-register.git
synced 2026-05-16 11:06:45 +08:00
252 lines · 8.5 KiB · Python
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel, Field
|
||
from sqlmodel import Session, select
|
||
from typing import Optional
|
||
from core.db import TaskLog, engine
|
||
import time, json, asyncio, threading
|
||
|
||
router = APIRouter(prefix="/tasks", tags=["tasks"])

# In-memory registry of background registration tasks, keyed by task_id.
_tasks: dict = {}
# Guards every read/write of _tasks (touched by worker threads and handlers).
_tasks_lock = threading.Lock()

# Keep at most this many finished ("done"/"failed") tasks in memory.
MAX_FINISHED_TASKS = 200
# NOTE(review): defined but never referenced in this file — confirm intended use.
CLEANUP_THRESHOLD = 250
|
||
|
||
|
||
def _cleanup_old_tasks():
    """Evict the oldest finished tasks once more than MAX_FINISHED_TASKS remain.

    Task ids embed a millisecond timestamp, so sorting them lexically orders
    tasks oldest-first.  Only tasks whose status is "done" or "failed" are
    eligible for eviction; pending and running tasks are never touched.
    """
    with _tasks_lock:
        finished_ids = sorted(
            tid for tid, task in _tasks.items()
            if task.get("status") in ("done", "failed")
        )
        excess = len(finished_ids) - MAX_FINISHED_TASKS
        if excess <= 0:
            return
        for stale_id in finished_ids[:excess]:
            del _tasks[stale_id]
|
||
|
||
|
||
class RegisterTaskRequest(BaseModel):
    """Request payload for POST /tasks/register."""
    # Platform key used to resolve the platform class via core.registry.get().
    platform: str
    # Optional fixed credentials; when omitted they are generated downstream.
    email: Optional[str] = None
    password: Optional[str] = None
    # Number of accounts to register in this task.
    count: int = 1
    # Requested parallel workers (capped to 5 in _run_register).
    concurrency: int = 1
    # Fixed proxy URL; when absent a proxy is drawn from the pool per attempt.
    proxy: Optional[str] = None
    executor_type: str = "protocol"
    captcha_solver: str = "yescaptcha"
    # Free-form options, e.g. "mail_provider" (defaults to "laoudo").
    extra: dict = Field(default_factory=dict)
|
||
|
||
|
||
def _log(task_id: str, msg: str):
    """Append a timestamped line to the task's log buffer and echo it to stdout."""
    entry = "[{}] {}".format(time.strftime("%H:%M:%S"), msg)
    with _tasks_lock:
        task = _tasks.get(task_id)
        if task is not None:
            task.setdefault("logs", []).append(entry)
    print(entry)
|
||
|
||
|
||
def _save_task_log(platform: str, email: str, status: str,
                   error: str = "", detail: Optional[dict] = None):
    """Persist one TaskLog row recording a registration attempt.

    Args:
        platform: Platform key the attempt ran against.
        email: Account email ("" when unknown, e.g. the attempt failed early).
        status: Outcome marker — callers in this file pass "success" or "failed".
        error: Error message for failed attempts.
        detail: Extra payload, serialized to JSON (non-ASCII preserved).
    """
    with Session(engine) as s:
        log = TaskLog(
            platform=platform,
            email=email,
            status=status,
            error=error,
            detail_json=json.dumps(detail or {}, ensure_ascii=False),
        )
        s.add(log)
        s.commit()
|
||
|
||
|
||
def _auto_upload_cpa(task_id: str, account):
    """After a successful registration, auto-upload the account to CPA.

    Only applies to the "chatgpt" platform, and only when a CPA API URL has
    been configured in config_store.  Any failure is logged to the task and
    swallowed so this step never fails the registration itself.
    """
    if getattr(account, "platform", "") != "chatgpt":
        return
    try:
        from core.config_store import config_store
        cpa_url = config_store.get("cpa_api_url", "")
        if cpa_url:
            from platforms.chatgpt.cpa_upload import generate_token_json, upload_to_cpa

            # Adapter: generate_token_json reads token fields as attributes,
            # while this account keeps them inside its `extra` dict.
            class _A: pass
            a = _A()
            a.email = account.email
            extra = account.extra or {}
            # Prefer the token from extra; fall back to the account's token field.
            a.access_token = extra.get("access_token") or account.token
            a.refresh_token = extra.get("refresh_token", "")
            a.id_token = extra.get("id_token", "")

            token_data = generate_token_json(a)
            ok, msg = upload_to_cpa(token_data)
            _log(task_id, f" [CPA] {'✓ ' + msg if ok else '✗ ' + msg}")
    except Exception as e:
        _log(task_id, f" [CPA] 自动上传异常: {e}")
|
||
|
||
|
||
def _run_register(task_id: str, req: RegisterTaskRequest):
    """Background worker: register `req.count` accounts on `req.platform`.

    Runs inside a FastAPI background task.  Status, progress, and per-line
    logs are published into the module-level `_tasks` dict (under
    `_tasks_lock`) so the HTTP handlers and the SSE stream can observe them;
    one TaskLog row is persisted per attempt.

    Fix: removed the unused module-scope `RegisterConfig` built up front —
    each worker constructs its own `_config` (with its own proxy) anyway.
    """
    from core.registry import get
    from core.base_platform import RegisterConfig
    from core.db import save_account
    from core.base_mailbox import create_mailbox

    with _tasks_lock:
        _tasks[task_id]["status"] = "running"
    success = 0
    errors = []

    try:
        PlatformCls = get(req.platform)
        # Shared mailbox client; when running concurrently each worker clones
        # it in _do_one so connection state isn't shared across threads.
        mailbox = create_mailbox(
            provider=req.extra.get("mail_provider", "laoudo"),
            extra=req.extra,
            proxy=req.proxy,
        )

        def _do_one(i: int):
            """Register one account; return True on success, else the error string."""
            from core.proxy_pool import proxy_pool
            # Prefer the explicitly requested proxy; otherwise rotate the pool.
            _proxy = req.proxy
            if not _proxy:
                _proxy = proxy_pool.get_next()
            # Per-attempt config so each attempt carries its own proxy.
            _config = RegisterConfig(
                executor_type=req.executor_type,
                captcha_solver=req.captcha_solver,
                proxy=_proxy,
                extra=req.extra,
            )
            # NOTE(review): shallow re-construction via __dict__ — assumes the
            # mailbox class accepts all of its attributes as kwargs; confirm.
            _mailbox = mailbox.__class__(**mailbox.__dict__) if req.concurrency > 1 else mailbox
            _platform = PlatformCls(config=_config, mailbox=_mailbox)
            _platform._log_fn = lambda msg: _log(task_id, msg)
            try:
                with _tasks_lock:
                    _tasks[task_id]["progress"] = f"{i+1}/{req.count}"
                _log(task_id, f"开始注册第 {i+1}/{req.count} 个账号")
                if _proxy: _log(task_id, f"使用代理: {_proxy}")
                account = _platform.register(
                    email=req.email or None,
                    password=req.password,
                )
                save_account(account)
                if _proxy: proxy_pool.report_success(_proxy)
                _log(task_id, f"✓ 注册成功: {account.email}")
                _save_task_log(req.platform, account.email, "success")
                _auto_upload_cpa(task_id, account)
                # Surface the upgrade/cashier URL when the platform returned one.
                cashier_url = (account.extra or {}).get("cashier_url", "")
                if cashier_url:
                    _log(task_id, f" [升级链接] {cashier_url}")
                    with _tasks_lock:
                        _tasks[task_id].setdefault("cashier_urls", []).append(cashier_url)
                return True
            except Exception as e:
                if _proxy: proxy_pool.report_fail(_proxy)
                _log(task_id, f"✗ 注册失败: {e}")
                _save_task_log(req.platform, req.email or "", "failed", error=str(e))
                return str(e)

        from concurrent.futures import ThreadPoolExecutor, as_completed
        # Never spawn more workers than accounts, hard-capped at 5.
        max_workers = min(req.concurrency, req.count, 5)
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = [pool.submit(_do_one, i) for i in range(req.count)]
            for f in as_completed(futures):
                result = f.result()
                if result is True:
                    success += 1
                else:
                    errors.append(result)
    except Exception as e:
        # Setup failure (unknown platform, mailbox creation, ...): fail the whole task.
        _log(task_id, f"致命错误: {e}")
        with _tasks_lock:
            _tasks[task_id]["status"] = "failed"
            _tasks[task_id]["error"] = str(e)
        return

    with _tasks_lock:
        _tasks[task_id]["status"] = "done"
        _tasks[task_id]["success"] = success
        _tasks[task_id]["errors"] = errors
    _log(task_id, f"完成: 成功 {success} 个, 失败 {len(errors)} 个")
    _cleanup_old_tasks()
|
||
|
||
|
||
@router.post("/register")
def create_register_task(
    req: RegisterTaskRequest,
    background_tasks: BackgroundTasks,
):
    """Create a registration task and run it in the background.

    Returns the task id immediately; progress is polled via GET /tasks/{id}
    or streamed via GET /tasks/{id}/logs/stream.

    Fix: the id was previously just the millisecond timestamp, so two
    requests arriving in the same millisecond collided and one task's state
    silently overwrote the other's.  A random suffix makes ids unique while
    the timestamp prefix keeps them sortable by creation time (relied on by
    the cleanup's oldest-first eviction).
    """
    import uuid
    task_id = f"task_{int(time.time()*1000)}_{uuid.uuid4().hex[:6]}"
    with _tasks_lock:
        _tasks[task_id] = {"id": task_id, "status": "pending",
                           "progress": f"0/{req.count}", "logs": []}
    background_tasks.add_task(_run_register, task_id, req)
    return {"task_id": task_id}
|
||
|
||
|
||
@router.get("/logs")
def get_logs(platform: Optional[str] = None, page: int = 1, page_size: int = 50):
    """Return persisted TaskLog records, newest first, paginated.

    Args:
        platform: Optional exact-match filter on TaskLog.platform.
        page: 1-based page number; values < 1 are clamped to 1 (a negative
            OFFSET would otherwise be sent to the database).
        page_size: Rows per page; values < 1 are clamped to 1.

    Returns:
        {"total": <matching row count>, "items": <rows for this page>}.
    """
    page = max(page, 1)
    page_size = max(page_size, 1)
    with Session(engine) as s:
        q = select(TaskLog)
        if platform:
            q = q.where(TaskLog.platform == platform)
        q = q.order_by(TaskLog.id.desc())
        # NOTE: counts by materializing every matching row; acceptable for a
        # small log table, switch to a COUNT(*) query if it grows large.
        total = len(s.exec(q).all())
        items = s.exec(q.offset((page - 1) * page_size).limit(page_size)).all()
        return {"total": total, "items": items}
|
||
|
||
|
||
@router.get("/{task_id}/logs/stream")
async def stream_logs(task_id: str, since: int = 0):
    """SSE endpoint streaming a task's log lines in real time.

    Args:
        task_id: In-memory task id returned by POST /tasks/register.
        since: Number of log lines the client already holds; streaming
            resumes from that index.

    Emits one `data:` event per new log line, then a final
    `{"done": true, "status": ...}` event once the task finishes.

    Fix: if the task was evicted by _cleanup_old_tasks mid-stream, the old
    loop saw status "" forever and never terminated; it now also ends the
    stream when the task has disappeared from the registry.
    """
    with _tasks_lock:
        if task_id not in _tasks:
            raise HTTPException(404, "任务不存在")

    async def event_generator():
        sent = since
        while True:
            with _tasks_lock:
                exists = task_id in _tasks
                logs = list(_tasks.get(task_id, {}).get("logs", []))
                status = _tasks.get(task_id, {}).get("status", "")
            while sent < len(logs):
                yield f"data: {json.dumps({'line': logs[sent]})}\n\n"
                sent += 1
            if not exists or status in ("done", "failed"):
                yield f"data: {json.dumps({'done': True, 'status': status})}\n\n"
                break
            await asyncio.sleep(0.5)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable nginx proxy buffering for SSE
        },
    )
|
||
|
||
|
||
@router.get("/{task_id}")
def get_task(task_id: str):
    """Return the in-memory state dict of a single task (404 if unknown)."""
    with _tasks_lock:
        task = _tasks.get(task_id)
    if task is None:
        raise HTTPException(404, "任务不存在")
    return task
|
||
|
||
|
||
@router.get("")
def list_tasks():
    """Return the state dicts of all in-memory tasks."""
    with _tasks_lock:
        snapshot = [task for task in _tasks.values()]
    return snapshot
|