|
|
""" |
|
|
领域模型模块 |
|
|
|
|
|
定义训练任务相关的核心数据结构 |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from datetime import datetime |
|
|
from enum import Enum |
|
|
from typing import Dict, Optional, Any |
|
|
|
|
|
|
|
|
class TaskStatus(Enum):
    """Enumeration of the states a task can be in.

    Each member's value is the lowercase string form of its name, which is
    the representation used when a task is converted to a dictionary.
    """

    # Not yet running.
    QUEUED = "queued"
    RUNNING = "running"

    # States other than queued/running.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
    INTERRUPTED = "interrupted"
|
|
|
|
|
|
|
|
@dataclass
class Task:
    """Domain model for a training task.

    Attributes:
        id: Unique task identifier.
        job_id: Queue job ID (generated by the task queue).
        exp_name: Experiment name.
        status: Task status.
        config: Task configuration (contains all training parameters).
        current_stage: Stage currently being executed.
        progress: Overall progress (0.0-1.0).
        stage_progress: Progress of the current stage (0.0-1.0).
        message: Latest status message.
        error_message: Error information (on failure).
        created_at: Creation time; defaults to ``datetime.utcnow()``.
        started_at: Time execution started.
        completed_at: Completion time.

    Example:
        >>> task = Task(
        ...     id="task-123",
        ...     exp_name="my_voice",
        ...     config={"version": "v2", "batch_size": 4}
        ... )
        >>> task.status
        <TaskStatus.QUEUED: 'queued'>
    """

    id: str
    exp_name: str
    config: Dict[str, Any]
    job_id: Optional[str] = None
    status: TaskStatus = TaskStatus.QUEUED
    current_stage: Optional[str] = None
    progress: float = 0.0
    stage_progress: float = 0.0
    message: Optional[str] = None
    error_message: Optional[str] = None
    # Naive timestamp taken from datetime.utcnow() at construction time.
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    @staticmethod
    def _format_datetime(value: Optional[datetime]) -> Optional[str]:
        """Return the ISO-8601 string for *value*, or None when it is unset."""
        # Explicit None check rather than truthiness: only "unset" maps to None.
        return value.isoformat() if value is not None else None

    @staticmethod
    def _parse_datetime(value: Any) -> Optional[datetime]:
        """Convert *value* to a datetime; None and datetime pass through."""
        if value is None or isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the task to a dictionary.

        Returns:
            A dict with one key per field. ``status`` is stored as the enum's
            string value and the three timestamps as ISO-8601 strings (or
            None when unset). ``config`` is the same dict object held by the
            task, not a copy.
        """
        return {
            "id": self.id,
            "job_id": self.job_id,
            "exp_name": self.exp_name,
            "status": self.status.value,
            "config": self.config,
            "current_stage": self.current_stage,
            "progress": self.progress,
            "stage_progress": self.stage_progress,
            "message": self.message,
            "error_message": self.error_message,
            "created_at": self._format_datetime(self.created_at),
            "started_at": self._format_datetime(self.started_at),
            "completed_at": self._format_datetime(self.completed_at),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Task":
        """Create a task from a dictionary.

        Args:
            data: Mapping with at least ``"id"`` and ``"exp_name"``. All other
                keys are optional and fall back to the field defaults.
                ``status`` may be a string value or a ``TaskStatus`` member;
                timestamps may be ISO-8601 strings or ``datetime`` objects.

        Returns:
            A new ``Task`` instance.

        Raises:
            KeyError: If ``"id"`` or ``"exp_name"`` is missing.
            ValueError: If ``status`` is a string that is not a valid
                ``TaskStatus`` value, or a timestamp string is not in a
                format ``datetime.fromisoformat`` accepts.
        """
        # A missing or null status means "queued", the field default.
        status = data.get("status")
        if status is None:
            status = TaskStatus.QUEUED
        elif isinstance(status, str):
            status = TaskStatus(status)

        # A null config would violate the Dict annotation; use an empty dict.
        config = data.get("config")
        if config is None:
            config = {}

        kwargs: Dict[str, Any] = {
            "id": data["id"],
            "job_id": data.get("job_id"),
            "exp_name": data["exp_name"],
            "status": status,
            "config": config,
            "current_stage": data.get("current_stage"),
            "progress": data.get("progress", 0.0),
            "stage_progress": data.get("stage_progress", 0.0),
            "message": data.get("message"),
            "error_message": data.get("error_message"),
            "started_at": cls._parse_datetime(data.get("started_at")),
            "completed_at": cls._parse_datetime(data.get("completed_at")),
        }

        # Only pass created_at when the input supplies one. Passing None would
        # bypass the default_factory and leave a non-optional field unset.
        created_at = cls._parse_datetime(data.get("created_at"))
        if created_at is not None:
            kwargs["created_at"] = created_at

        return cls(**kwargs)
|
|
|
|
|
|
|
|
@dataclass
class ProgressInfo:
    """Data structure describing a progress update.

    Used to pass progress updates between the child process and the main
    process.

    Attributes:
        type: Message type ("progress", "log", "error", "heartbeat").
        stage: Name of the current stage.
        stage_index: Index of the current stage.
        total_stages: Total number of stages.
        progress: Progress within the stage (0.0-1.0).
        overall_progress: Overall progress (0.0-1.0).
        message: Progress message.
        status: Status.
        data: Additional data.
    """

    type: str = "progress"
    stage: Optional[str] = None
    stage_index: Optional[int] = None
    total_stages: Optional[int] = None
    progress: float = 0.0
    overall_progress: float = 0.0
    message: Optional[str] = None
    status: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the progress info to a dictionary.

        Returns:
            A dict keyed by field name, in field declaration order. The
            ``data`` entry is the same dict object held by the instance.
        """
        names = (
            "type",
            "stage",
            "stage_index",
            "total_stages",
            "progress",
            "overall_progress",
            "message",
            "status",
            "data",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
        """Create an instance from a dictionary.

        Args:
            data: Mapping of field names to values. Every key is optional;
                absent keys take the fallback listed below, or None.

        Returns:
            A new ``ProgressInfo`` instance.
        """
        # Built per call so the empty-dict fallback is never shared between
        # instances.
        fallbacks = {
            "type": "progress",
            "progress": 0.0,
            "overall_progress": 0.0,
            "data": {},
        }
        names = (
            "type",
            "stage",
            "stage_index",
            "total_stages",
            "progress",
            "overall_progress",
            "message",
            "status",
            "data",
        )
        values = {name: data.get(name, fallbacks.get(name)) for name in names}
        return cls(**values)
|
|
|