mirror of
https://github.com/zc-zhangchen/any-auto-register.git
synced 2026-05-08 08:14:04 +08:00
385 lines
14 KiB
Python
385 lines
14 KiB
Python
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
from sqlmodel import Session, select
|
|
from typing import Optional
|
|
from copy import deepcopy
|
|
from core.db import TaskLog, engine
|
|
import time, json, asyncio, threading, logging
|
|
|
|
router = APIRouter(prefix="/tasks", tags=["tasks"])
logger = logging.getLogger(__name__)

# In-memory table of register tasks keyed by task id. Records are created by
# _create_task_record and then mutated by worker threads, so every read and
# write in this module goes through _tasks_lock.
_tasks: dict = {}
_tasks_lock: threading.Lock = threading.Lock()

# How many finished ("done"/"failed") tasks _cleanup_old_tasks keeps around.
MAX_FINISHED_TASKS = 200
# NOTE(review): not referenced anywhere in this module; presumably meant as the
# size at which cleanup should start trimming back to MAX_FINISHED_TASKS. Confirm.
CLEANUP_THRESHOLD = 250
|
|
|
|
|
|
def _cleanup_old_tasks():
    """Evict the oldest finished tasks once more than MAX_FINISHED_TASKS exist.

    Only records whose status is "done" or "failed" are eligible. Pending and
    running tasks are never removed. "Oldest" means smallest task id under a
    plain string sort. Ids are ``task_<epoch-ms>``, so this follows creation
    time while the millisecond value keeps the same number of digits.
    """
    with _tasks_lock:
        finished_ids = sorted(
            task_id
            for task_id, record in _tasks.items()
            if record.get("status") in ("done", "failed")
        )
        surplus = len(finished_ids) - MAX_FINISHED_TASKS
        if surplus <= 0:
            return
        for task_id in finished_ids[:surplus]:
            del _tasks[task_id]
|
|
|
|
|
|
class RegisterTaskRequest(BaseModel):
    # Body of POST /tasks/register; also the input of enqueue_register_task.
    platform: str  # key looked up via core.registry.get()
    email: str | None = None  # forwarded to the platform's register(email=...)
    password: str | None = None  # forwarded to the platform's register(password=...)
    count: int = 1  # number of accounts to register in this task
    concurrency: int = 1  # requested worker threads (_run_register caps this at 5)
    register_delay_seconds: float = 0  # minimum spacing between account starts
    proxy: str | None = None  # fixed proxy; when empty, the proxy pool supplies one
    executor_type: str = "protocol"
    captcha_solver: str = "yescaptcha"
    # Overlaid on config_store values; None / "" entries are ignored at run time.
    extra: dict = Field(default_factory=dict)
|
|
|
|
|
|
class TaskLogBatchDeleteRequest(BaseModel):
    # Body of POST /tasks/logs/batch-delete: primary keys of TaskLog rows to remove.
    ids: list[int]
|
|
|
|
|
|
# Platforms the LuckMail mail channel cannot register for yet.
_LUCKMAIL_UNSUPPORTED_PLATFORMS = ("tavily", "openblocklabs")
# Platform name -> LuckMail project code; unlisted platforms use their own name.
_LUCKMAIL_PROJECT_CODES = {
    "trae": "trae",
    "cursor": "cursor",
    "grok": "grok",
    "kiro": "kiro",
    "chatgpt": "openai",
}


def _prepare_register_request(req: RegisterTaskRequest) -> RegisterTaskRequest:
    """Return an independent copy of ``req`` adjusted for its mail provider.

    ``extra`` is deep-copied so later mutations never touch the caller's
    request. The effective mail provider is ``extra["mail_provider"]`` if
    truthy, otherwise the ``mail_provider`` config value. For "luckmail" the
    platform is validated and ``extra["luckmail_project_code"]`` is set.

    Raises:
        HTTPException: 400 when LuckMail is selected for an unsupported platform.
    """
    from core.config_store import config_store

    snapshot = req.model_dump()
    snapshot["extra"] = deepcopy(snapshot.get("extra") or {})
    prepared = RegisterTaskRequest(**snapshot)

    provider = prepared.extra.get("mail_provider") or config_store.get("mail_provider", "")
    if provider != "luckmail":
        return prepared

    platform = prepared.platform
    if platform in _LUCKMAIL_UNSUPPORTED_PLATFORMS:
        raise HTTPException(400, f"LuckMail 渠道暂时不支持 {platform} 项目注册")
    prepared.extra["luckmail_project_code"] = _LUCKMAIL_PROJECT_CODES.get(platform, platform)
    return prepared
|
|
|
|
|
|
def _create_task_record(task_id: str, req: RegisterTaskRequest, source: str, meta: dict | None = None):
    """Insert a fresh "pending" record for ``task_id`` into the task table.

    Any existing record under the same id is replaced. ``meta`` defaults to
    an empty dict. Progress starts at ``0/<req.count>`` with no log lines.
    """
    record = {
        "id": task_id,
        "status": "pending",
        "platform": req.platform,
        "source": source,
        "meta": meta if meta else {},
        "progress": "0/{}".format(req.count),
        "logs": [],
    }
    with _tasks_lock:
        _tasks[task_id] = record
|
|
|
|
|
|
# Last millisecond value handed out by _new_task_id (guarded by _tasks_lock).
_last_task_id_ms = 0


def _new_task_id() -> str:
    """Return a process-unique, time-ordered id of the form ``task_<epoch-ms>``.

    Plain ``int(time.time()*1000)`` collides when two tasks are enqueued in
    the same millisecond (or the clock steps back). The second record would
    then overwrite the first in ``_tasks``, and both workers would write into
    one dict. Bumping past the last issued value keeps ids unique and
    monotonic without changing their format.
    """
    global _last_task_id_ms
    with _tasks_lock:
        ms = max(int(time.time() * 1000), _last_task_id_ms + 1)
        _last_task_id_ms = ms
    return f"task_{ms}"


def enqueue_register_task(
    req: RegisterTaskRequest,
    *,
    background_tasks: BackgroundTasks | None = None,
    source: str = "manual",
    meta: dict | None = None,
) -> str:
    """Prepare ``req``, record a pending task, and schedule ``_run_register`` for it.

    Args:
        req: registration request. It is copied by _prepare_register_request,
            so the caller's object is not mutated.
        background_tasks: when given, the run is added to FastAPI's
            BackgroundTasks. Otherwise it starts on a new daemon thread.
        source: free-form origin tag stored on the task record (e.g. "manual").
        meta: optional extra data stored on the task record.

    Returns:
        The new task id.

    Raises:
        HTTPException: propagated from _prepare_register_request (400 for
            platforms LuckMail does not support).
    """
    prepared = _prepare_register_request(req)
    task_id = _new_task_id()
    _create_task_record(task_id, prepared, source, meta)
    if background_tasks is None:
        thread = threading.Thread(target=_run_register, args=(task_id, prepared), daemon=True)
        thread.start()
    else:
        background_tasks.add_task(_run_register, task_id, prepared)
    return task_id
|
|
|
|
|
|
def has_active_register_task(*, platform: str | None = None, source: str | None = None) -> bool:
|
|
with _tasks_lock:
|
|
for task in _tasks.values():
|
|
if task.get("status") not in ("pending", "running"):
|
|
continue
|
|
if platform and task.get("platform") != platform:
|
|
continue
|
|
if source and task.get("source") != source:
|
|
continue
|
|
return True
|
|
return False
|
|
|
|
|
|
def _log(task_id: str, msg: str):
    """Append a ``[HH:MM:SS] msg`` line to a task's log and echo it to stdout.

    Unknown task ids are ignored for storage, but the line is still printed.
    """
    stamp = time.strftime("%H:%M:%S")
    line = f"[{stamp}] {msg}"
    with _tasks_lock:
        record = _tasks.get(task_id)
        if record is not None:
            record.setdefault("logs", []).append(line)
    print(line)
|
|
|
|
|
|
def _save_task_log(platform: str, email: str, status: str,
                   error: str = "", detail: dict | None = None):
    """Insert one TaskLog row in its own session and commit it.

    ``detail`` is stored JSON-encoded in ``detail_json``, with non-ASCII text
    kept as-is. A missing or empty ``detail`` is stored as ``"{}"``.
    """
    payload = json.dumps(detail or {}, ensure_ascii=False)
    record = TaskLog(
        platform=platform,
        email=email,
        status=status,
        error=error,
        detail_json=payload,
    )
    with Session(engine) as session:
        session.add(record)
        session.commit()
|
|
|
|
|
|
def _auto_upload_integrations(task_id: str, account):
    """Sync a freshly registered account to external systems (best effort).

    Calls ``services.external_sync.sync_account(account)`` and writes one task
    log line per result, marked with ✓ or ✗. Any ``Exception``, including a
    failed import of the sync module, is logged to the task instead of being
    raised, so the caller's success path is not interrupted.
    """
    try:
        from services.external_sync import sync_account

        for outcome in sync_account(account):
            label = outcome.get("name", "Auto Upload")
            marker = "✓ " if bool(outcome.get("ok")) else "✗ "
            detail = outcome.get("msg", "")
            _log(task_id, f" [{label}] {marker + detail}")
    except Exception as e:
        _log(task_id, f" [Auto Upload] 自动导入异常: {e}")
|
|
|
|
|
|
def _run_register(task_id: str, req: RegisterTaskRequest):
    """Register ``req.count`` accounts for ``task_id`` and record the outcome.

    Runs synchronously. enqueue_register_task schedules it on a daemon thread
    or as a FastAPI background task. Accounts are registered on a thread pool
    of at most 5 workers, bounded by ``req.concurrency`` and ``req.count``
    (never fewer than 1). Every attempt is logged to the task and persisted as
    a TaskLog row.

    Final task state:
        * "done" with ``success`` (int) and ``errors`` (list of messages); or
        * "failed" with ``error`` when setup itself raises (e.g. unknown
          platform).
    Finished tasks are trimmed via _cleanup_old_tasks in both cases.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    with _tasks_lock:
        _tasks[task_id]["status"] = "running"
    success = 0
    errors = []
    # Serialises account start times when register_delay_seconds > 0.
    start_gate_lock = threading.Lock()
    next_start_time = time.time()

    try:
        # Project imports sit inside the try so an import failure marks the
        # task "failed" instead of leaving it stuck in an active state.
        from core.registry import get
        from core.base_platform import RegisterConfig
        from core.db import save_account
        from core.base_mailbox import create_mailbox

        PlatformCls = get(req.platform)

        def _merged_extra() -> dict:
            # Global config overlaid with the request's non-empty extra values.
            # Each call returns a fresh dict.
            from core.config_store import config_store

            merged = config_store.get_all().copy()
            merged.update({k: v for k, v in req.extra.items() if v is not None and v != ""})
            return merged

        def _build_mailbox(proxy: Optional[str]):
            # Mailbox provider falls back to "laoudo" when none is configured.
            merged_extra = _merged_extra()
            return create_mailbox(
                provider=merged_extra.get("mail_provider", "laoudo"),
                extra=merged_extra,
                proxy=proxy,
            )

        def _stamp_mail_metadata(account, merged_extra: dict, mailbox) -> None:
            # Record the mail provider on the account. For LuckMail + ChatGPT,
            # also record the mailbox token and LuckMail settings. Keys the
            # platform already set are left untouched (setdefault).
            if not isinstance(account.extra, dict):
                return
            mail_provider = merged_extra.get("mail_provider", "")
            if mail_provider:
                account.extra.setdefault("mail_provider", mail_provider)
            if mail_provider == "luckmail" and req.platform == "chatgpt":
                mailbox_token = getattr(mailbox, "_token", "") or ""
                if mailbox_token:
                    account.extra.setdefault("mailbox_token", mailbox_token)
                for key in (
                    "luckmail_project_code",
                    "luckmail_email_type",
                    "luckmail_domain",
                    "luckmail_base_url",
                ):
                    if merged_extra.get(key):
                        account.extra.setdefault(key, merged_extra.get(key))

        def _do_one(i: int):
            # Register account #i. Returns True on success, else the error text.
            nonlocal next_start_time
            # Bound before the try so the failure path can always test it, even
            # when the proxy_pool import itself fails.
            _proxy = None
            try:
                from core.proxy_pool import proxy_pool

                _proxy = req.proxy
                if not _proxy:
                    _proxy = proxy_pool.get_next()
                if req.register_delay_seconds > 0:
                    # Sleeping while holding the gate spaces consecutive
                    # account starts at least register_delay_seconds apart.
                    with start_gate_lock:
                        now = time.time()
                        wait_seconds = max(0.0, next_start_time - now)
                        if wait_seconds > 0:
                            _log(task_id, f"第 {i+1} 个账号启动前延迟 {wait_seconds:g} 秒")
                            time.sleep(wait_seconds)
                        next_start_time = time.time() + req.register_delay_seconds

                merged_extra = _merged_extra()
                _config = RegisterConfig(
                    executor_type=req.executor_type,
                    captcha_solver=req.captcha_solver,
                    proxy=_proxy,
                    extra=merged_extra,
                )
                _mailbox = _build_mailbox(_proxy)
                _platform = PlatformCls(config=_config, mailbox=_mailbox)
                # Route platform/mailbox log output into this task's log.
                _platform._log_fn = lambda msg: _log(task_id, msg)
                if getattr(_platform, "mailbox", None) is not None:
                    _platform.mailbox._log_fn = _platform._log_fn
                with _tasks_lock:
                    _tasks[task_id]["progress"] = f"{i+1}/{req.count}"
                _log(task_id, f"开始注册第 {i+1}/{req.count} 个账号")
                if _proxy:
                    _log(task_id, f"使用代理: {_proxy}")

                account = _platform.register(
                    email=req.email or None,
                    password=req.password,
                )
                _stamp_mail_metadata(account, merged_extra, _mailbox)
                save_account(account)
                if _proxy:
                    proxy_pool.report_success(_proxy)
                _log(task_id, f"✓ 注册成功: {account.email}")
                _save_task_log(req.platform, account.email, "success")
                _auto_upload_integrations(task_id, account)

                cashier_url = (account.extra or {}).get("cashier_url", "")
                if cashier_url:
                    _log(task_id, f" [升级链接] {cashier_url}")
                    with _tasks_lock:
                        _tasks[task_id].setdefault("cashier_urls", []).append(cashier_url)
                return True
            except Exception as e:
                if _proxy:
                    proxy_pool.report_fail(_proxy)
                _log(task_id, f"✗ 注册失败: {e}")
                _save_task_log(req.platform, req.email or "", "failed", error=str(e))
                return str(e)

        # ThreadPoolExecutor raises ValueError for max_workers <= 0, which a
        # non-positive count/concurrency would otherwise produce.
        max_workers = max(1, min(req.concurrency, req.count, 5))
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            futures = [pool.submit(_do_one, i) for i in range(req.count)]
            for f in as_completed(futures):
                try:
                    result = f.result()
                except Exception as e:
                    _log(task_id, f"✗ 任务线程异常: {e}")
                    errors.append(str(e))
                    continue
                if result is True:
                    success += 1
                else:
                    errors.append(result)
    except Exception as e:
        _log(task_id, f"致命错误: {e}")
        with _tasks_lock:
            _tasks[task_id]["status"] = "failed"
            _tasks[task_id]["error"] = str(e)
    else:
        # Log before flipping the status: log streamers stop as soon as they
        # observe a terminal status, so the summary must already be present.
        _log(task_id, f"完成: 成功 {success} 个, 失败 {len(errors)} 个")
        with _tasks_lock:
            _tasks[task_id]["status"] = "done"
            _tasks[task_id]["success"] = success
            _tasks[task_id]["errors"] = errors
    _cleanup_old_tasks()
|
|
|
|
|
|
@router.post("/register")
def create_register_task(
    req: RegisterTaskRequest,
    background_tasks: BackgroundTasks,
):
    """Start a registration task from the API and return its id.

    The run is scheduled on FastAPI's BackgroundTasks. LuckMail validation
    errors from request preparation surface as HTTP 400.
    """
    return {"task_id": enqueue_register_task(req, background_tasks=background_tasks)}
|
|
|
|
|
|
@router.get("/logs")
def get_logs(platform: str | None = None, page: int = 1, page_size: int = 50):
    """Return one page of TaskLog rows (newest first) plus the total match count.

    Args:
        platform: optional exact-match filter on ``TaskLog.platform``.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: rows per page; values below 1 are treated as 1. A negative
            LIMIT would otherwise mean "no limit" on some databases.

    Returns:
        ``{"total": <matching rows>, "items": [TaskLog, ...]}``.
    """
    from sqlmodel import func

    page = max(page, 1)
    page_size = max(page_size, 1)
    with Session(engine) as s:
        count_q = select(func.count()).select_from(TaskLog)
        q = select(TaskLog)
        if platform:
            count_q = count_q.where(TaskLog.platform == platform)
            q = q.where(TaskLog.platform == platform)
        # COUNT(*) in the database instead of loading every row to len() it.
        total = s.exec(count_q).one()
        q = q.order_by(TaskLog.id.desc())
        items = s.exec(q.offset((page - 1) * page_size).limit(page_size)).all()
        return {"total": total, "items": items}
|
|
|
|
|
|
@router.post("/logs/batch-delete")
def batch_delete_logs(body: TaskLogBatchDeleteRequest):
    """Delete TaskLog rows by id and report which ids were actually found.

    Duplicate ids are collapsed (first occurrence kept). At most 1000 distinct
    ids are accepted per call.

    Raises:
        HTTPException: 400 for an empty or oversized id list; 500 (after
            rolling back) if the database operation fails.
    """
    if not body.ids:
        raise HTTPException(400, "任务历史 ID 列表不能为空")

    # dict.fromkeys de-duplicates while preserving request order.
    requested = list(dict.fromkeys(body.ids))
    if len(requested) > 1000:
        raise HTTPException(400, "单次最多删除 1000 条任务历史")

    with Session(engine) as session:
        try:
            rows = session.exec(select(TaskLog).where(TaskLog.id.in_(requested))).all()
            deleted_ids = set()
            for row in rows:
                if row.id is not None:
                    deleted_ids.add(row.id)
                session.delete(row)
            session.commit()
        except Exception as e:
            session.rollback()
            logger.exception("批量删除任务历史失败")
            raise HTTPException(500, f"批量删除任务历史失败: {str(e)}")

    logger.info("批量删除任务历史成功: %s 条", len(deleted_ids))
    return {
        "deleted": len(deleted_ids),
        "not_found": [log_id for log_id in requested if log_id not in deleted_ids],
        "total_requested": len(requested),
    }
|
|
|
|
|
|
@router.get("/{task_id}/logs/stream")
async def stream_logs(task_id: str, since: int = 0):
    """Stream a task's log lines as Server-Sent Events.

    Each event is ``data: {"line": ...}``. After the task reaches "done" or
    "failed", a final ``data: {"done": true, "status": ...}`` event is sent and
    the stream ends. The task table is polled every 0.5 s.

    Args:
        task_id: task to follow; 404 if it is not known when the request arrives.
        since: number of already-seen log lines to skip (negative counts as 0).
    """
    with _tasks_lock:
        if task_id not in _tasks:
            raise HTTPException(404, "任务不存在")

    async def event_generator():
        # A negative index would replay lines from the end of the log.
        sent = max(0, since)
        while True:
            with _tasks_lock:
                task = _tasks.get(task_id)
                logs = list(task.get("logs", [])) if task is not None else []
                status = task.get("status", "") if task is not None else ""
            while sent < len(logs):
                yield f"data: {json.dumps({'line': logs[sent]})}\n\n"
                sent += 1
            # If the record vanished (e.g. evicted by cleanup), finish instead
            # of polling forever for a status that will never arrive.
            if task is None or status in ("done", "failed"):
                yield f"data: {json.dumps({'done': True, 'status': status})}\n\n"
                break
            await asyncio.sleep(0.5)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
@router.get("/{task_id}")
def get_task(task_id: str):
    """Return a snapshot of one in-memory task record.

    Raises:
        HTTPException: 404 if the task id is unknown.
    """
    with _tasks_lock:
        task = _tasks.get(task_id)
        if task is None:
            raise HTTPException(404, "任务不存在")
        # Copy while holding the lock. Worker threads keep appending logs and
        # adding keys, and FastAPI serialises the return value only after this
        # function has released the lock.
        return deepcopy(task)
|
|
|
|
|
|
@router.get("")
def list_tasks():
    """Return snapshots of every in-memory task record."""
    with _tasks_lock:
        # Deep copies taken under the lock. Worker threads mutate the live
        # records, and FastAPI serialises the response after the lock is
        # released.
        return [deepcopy(task) for task in _tasks.values()]
|