mirror of
https://github.com/zc-zhangchen/any-auto-register.git
synced 2026-05-08 08:14:04 +08:00
293 lines
8.5 KiB
Python
293 lines
8.5 KiB
Python
"""注册任务运行时控制与状态存储。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
import threading
|
|
import time
|
|
from typing import Any
|
|
|
|
|
|
class TaskInterruption(RuntimeError):
    """Common base for cooperative interruptions raised while a task runs.

    Concrete subclasses describe the kind of interruption, so handlers can
    catch a specific one or every interruption via this base class.
    """
|
|
|
|
|
|
class StopTaskRequested(TaskInterruption):
    """Signals that the whole task was stopped manually."""

    def __init__(self, message: str = "任务已手动停止") -> None:
        # Forward to the base exception so ``str(exc)`` returns ``message``.
        super().__init__(message)
|
|
|
|
|
|
class SkipCurrentAttemptRequested(TaskInterruption):
    """Signals that the current account was skipped manually."""

    def __init__(self, message: str = "已手动跳过当前账号") -> None:
        # Forward to the base exception so ``str(exc)`` returns ``message``.
        super().__init__(message)
|
|
|
|
|
|
class AttemptOutcome(str, Enum):
    """Possible outcomes of a single attempt.

    Mixing in ``str`` makes each member compare equal to its plain string
    value, e.g. ``AttemptOutcome.SUCCESS == "success"``.
    """

    SUCCESS = "success"
    FAILED = "failed"
    SKIPPED = "skipped"
    STOPPED = "stopped"
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AttemptResult:
|
|
outcome: AttemptOutcome
|
|
message: str = ""
|
|
|
|
@classmethod
|
|
def success(cls) -> "AttemptResult":
|
|
return cls(AttemptOutcome.SUCCESS)
|
|
|
|
@classmethod
|
|
def failed(cls, message: str) -> "AttemptResult":
|
|
return cls(AttemptOutcome.FAILED, message)
|
|
|
|
@classmethod
|
|
def skipped(cls, message: str) -> "AttemptResult":
|
|
return cls(AttemptOutcome.SKIPPED, message)
|
|
|
|
@classmethod
|
|
def stopped(cls, message: str) -> "AttemptResult":
|
|
return cls(AttemptOutcome.STOPPED, message)
|
|
|
|
|
|
class RegisterTaskControl:
    """Cooperative controller for stopping a task or skipping one account.

    Requests are only recorded here; they take effect when the running code
    calls :meth:`checkpoint`, which raises the matching interruption. All
    state is guarded by a single lock.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._stop_requested = False
        # Skip requests not yet consumed by ``checkpoint``.
        self._pending_skip_requests = 0

    def request_stop(self) -> None:
        """Flag the whole task for stopping; nothing here ever clears it."""
        with self._lock:
            self._stop_requested = True

    def request_skip_current(self) -> None:
        """Queue one skip request (requests accumulate as a counter)."""
        with self._lock:
            self._pending_skip_requests += 1

    def checkpoint(self, *, consume_skip: bool = True) -> None:
        """Raise if an interruption is pending, otherwise return quietly.

        Args:
            consume_skip: when False, queued skip requests are ignored and
                left untouched; only the stop flag is checked.

        Raises:
            StopTaskRequested: a stop was requested. This is checked first,
                so a stop wins over any queued skip.
            SkipCurrentAttemptRequested: ``consume_skip`` is True and at
                least one skip is queued; exactly one request is consumed.
        """
        with self._lock:
            if self._stop_requested:
                raise StopTaskRequested()
            if not consume_skip or self._pending_skip_requests <= 0:
                return
            self._pending_skip_requests -= 1
        raise SkipCurrentAttemptRequested()

    def is_stop_requested(self) -> bool:
        """Return the stop flag without raising."""
        with self._lock:
            return self._stop_requested

    def snapshot(self) -> dict[str, Any]:
        """Return a consistent point-in-time copy of the control state."""
        with self._lock:
            stop_flag = self._stop_requested
            queued_skips = self._pending_skip_requests
        return {
            "stop_requested": stop_flag,
            "pending_skip_requests": queued_skips,
        }
|
|
|
|
|
|
@dataclass
class RegisterTaskRecord:
    """Mutable state of a single registration task."""

    id: str
    platform: str
    source: str
    # Presumably the number of attempts planned for the task — TODO confirm.
    total: int
    # Free-form extra data; shallow-copied on export.
    meta: dict[str, Any] = field(default_factory=dict)
    # Lifecycle status string; starts as "pending".
    status: str = "pending"
    # Progress text, initially "0/0".
    progress: str = "0/0"
    logs: list[str] = field(default_factory=list)
    success: int = 0
    skipped: int = 0
    errors: list[str] = field(default_factory=list)
    cashier_urls: list[str] = field(default_factory=list)
    # Task-level error message; exported only when non-empty.
    error: str = ""
    # Unix timestamps produced by ``time.time()``.
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    # Per-task stop/skip controller; kept out of ``repr``.
    control: RegisterTaskControl = field(
        default_factory=RegisterTaskControl,
        repr=False,
    )

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict snapshot of the record.

        Containers are shallow-copied, so mutating the result does not affect
        the record. ``cashier_urls`` and ``error`` appear only when
        non-empty. ``total``, ``created_at`` and ``updated_at`` are not
        exported.
        """
        payload: dict[str, Any] = dict(
            id=self.id,
            status=self.status,
            platform=self.platform,
            source=self.source,
            meta=dict(self.meta),
            progress=self.progress,
            logs=[*self.logs],
            success=self.success,
            skipped=self.skipped,
            errors=[*self.errors],
            control=self.control.snapshot(),
        )
        if self.cashier_urls:
            payload["cashier_urls"] = [*self.cashier_urls]
        if self.error:
            payload["error"] = self.error
        return payload
|
|
|
|
|
|
class RegisterTaskStore:
    """Thread-safe in-memory store of registration task records.

    Every method works under the store lock, so records can be created,
    updated and snapshotted from several threads. Finished records are
    evicted by :meth:`cleanup` once the store grows beyond
    ``cleanup_threshold`` entries.
    """

    # Statuses that ``has_active`` treats as still in flight.
    _ACTIVE_STATUSES = frozenset({"pending", "running"})
    # Statuses whose records ``cleanup`` may evict.
    _FINISHED_STATUSES = frozenset({"done", "failed", "stopped"})

    def __init__(
        self,
        *,
        max_finished_tasks: int = 200,
        cleanup_threshold: int = 250,
    ):
        """Create an empty store.

        Args:
            max_finished_tasks: number of finished records ``cleanup`` keeps.
            cleanup_threshold: ``cleanup`` is a no-op while the store holds at
                most this many records (active ones included).

        Raises:
            ValueError: if ``max_finished_tasks`` is negative. A negative
                value would otherwise make ``cleanup`` evict every finished
                record.
        """
        if max_finished_tasks < 0:
            raise ValueError(
                f"max_finished_tasks must be >= 0, got {max_finished_tasks}"
            )
        self._lock = threading.Lock()
        self._records: dict[str, RegisterTaskRecord] = {}
        self.max_finished_tasks = max_finished_tasks
        self.cleanup_threshold = cleanup_threshold

    def create(
        self,
        task_id: str,
        *,
        platform: str,
        total: int,
        source: str,
        meta: dict[str, Any] | None = None,
    ) -> RegisterTaskRecord:
        """Create a ``pending`` record with progress ``"0/<total>"`` and return it.

        ``meta`` is shallow-copied so later caller-side changes do not leak
        into the record.

        NOTE(review): an existing record with the same ``task_id`` is
        replaced silently, including its control object.
        """
        with self._lock:
            record = RegisterTaskRecord(
                id=task_id,
                platform=platform,
                total=total,
                source=source,
                meta=dict(meta or {}),
                progress=f"0/{total}",
            )
            self._records[task_id] = record
            return record

    def exists(self, task_id: str) -> bool:
        """Return whether a record with ``task_id`` is stored."""
        with self._lock:
            return task_id in self._records

    def has_active(
        self,
        *,
        platform: str | None = None,
        source: str | None = None,
    ) -> bool:
        """Return whether any pending/running task matches the filters.

        A falsy ``platform`` or ``source`` (``None`` or ``""``) matches any
        value.
        """
        with self._lock:
            for record in self._records.values():
                if record.status not in self._ACTIVE_STATUSES:
                    continue
                if platform and record.platform != platform:
                    continue
                if source and record.source != source:
                    continue
                return True
            return False

    def control_for(self, task_id: str) -> RegisterTaskControl:
        """Return the task's control object; raises ``KeyError`` if unknown."""
        with self._lock:
            return self._records[task_id].control

    def request_stop(self, task_id: str) -> dict[str, Any]:
        """Ask the task to stop and return the resulting control snapshot.

        The control is looked up under the store lock but signalled outside
        it; the control object serialises its own state.
        """
        control = self.control_for(task_id)
        control.request_stop()
        return control.snapshot()

    def request_skip_current(self, task_id: str) -> dict[str, Any]:
        """Queue a skip of the current account and return the control snapshot."""
        control = self.control_for(task_id)
        control.request_skip_current()
        return control.snapshot()

    def append_log(self, task_id: str, entry: str) -> None:
        """Append a log line to the task.

        Unlike the other mutators, which raise ``KeyError``, an unknown
        ``task_id`` is silently ignored here.
        """
        with self._lock:
            record = self._records.get(task_id)
            if record is None:
                return
            record.logs.append(entry)
            record.updated_at = time.time()

    def mark_running(self, task_id: str) -> None:
        """Set the task status to ``running``; raises ``KeyError`` if unknown."""
        with self._lock:
            record = self._records[task_id]
            record.status = "running"
            record.updated_at = time.time()

    def set_progress(self, task_id: str, progress: str) -> None:
        """Replace the task's progress text; raises ``KeyError`` if unknown."""
        with self._lock:
            record = self._records[task_id]
            record.progress = progress
            record.updated_at = time.time()

    def add_cashier_url(self, task_id: str, cashier_url: str) -> None:
        """Append a cashier URL to the task; raises ``KeyError`` if unknown."""
        with self._lock:
            record = self._records[task_id]
            record.cashier_urls.append(cashier_url)
            record.updated_at = time.time()

    def finish(
        self,
        task_id: str,
        *,
        status: str,
        success: int,
        skipped: int,
        errors: list[str],
        error: str = "",
    ) -> None:
        """Store the final status and counters of a task.

        ``errors`` is copied. Raises ``KeyError`` if ``task_id`` is unknown.
        """
        with self._lock:
            record = self._records[task_id]
            record.status = status
            record.success = success
            record.skipped = skipped
            record.errors = list(errors)
            record.error = error
            record.updated_at = time.time()

    def snapshot(self, task_id: str) -> dict[str, Any]:
        """Return ``to_dict()`` of one record; raises ``KeyError`` if unknown."""
        with self._lock:
            return self._records[task_id].to_dict()

    def list_snapshots(self) -> list[dict[str, Any]]:
        """Return ``to_dict()`` of every record, in insertion order."""
        with self._lock:
            return [record.to_dict() for record in self._records.values()]

    def log_state(self, task_id: str) -> tuple[list[str], str]:
        """Return a copy of the task's logs together with its current status."""
        with self._lock:
            record = self._records[task_id]
            return list(record.logs), record.status

    def cleanup(self) -> None:
        """Evict the oldest finished records once the store grows too large.

        Nothing happens while the store holds at most ``cleanup_threshold``
        records. Otherwise finished records are removed oldest-first (by
        ``created_at``) until at most ``max_finished_tasks`` remain. Records
        in any other status are never evicted.
        """
        with self._lock:
            if len(self._records) <= self.cleanup_threshold:
                return
            finished = [
                (task_id, record)
                for task_id, record in self._records.items()
                if record.status in self._FINISHED_STATUSES
            ]
            excess = len(finished) - self.max_finished_tasks
            if excess <= 0:
                return
            finished.sort(key=lambda item: item[1].created_at)
            for task_id, _ in finished[:excess]:
                self._records.pop(task_id, None)
|
|
|
|
|
|
# Public names exported by ``from <module> import *`` (alphabetical order).
__all__ = [
    "AttemptOutcome", "AttemptResult",
    "RegisterTaskControl", "RegisterTaskRecord", "RegisterTaskStore",
    "SkipCurrentAttemptRequested", "StopTaskRequested", "TaskInterruption",
]
|