File size: 5,710 Bytes
e43edbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
领域模型模块

定义训练任务相关的核心数据结构
"""

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, Optional, Any


class TaskStatus(Enum):
    """Enumeration of task statuses."""

    # Enqueued, waiting to be executed.
    QUEUED = "queued"
    # Currently executing.
    RUNNING = "running"
    # Finished.
    COMPLETED = "completed"
    # Failed.
    FAILED = "failed"
    # Cancelled.
    CANCELLED = "cancelled"
    # Interrupted (a task that was running when the application restarted).
    INTERRUPTED = "interrupted"


@dataclass
class Task:
    """
    Domain model for a training task.

    Attributes:
        id: Unique task identifier.
        job_id: Queue job ID (generated by the task queue).
        exp_name: Experiment name.
        status: Task status.
        config: Task configuration (holds all training parameters).
        current_stage: Stage currently being executed.
        progress: Overall progress (0.0-1.0).
        stage_progress: Progress within the current stage (0.0-1.0).
        message: Latest status message.
        error_message: Error message (when the task failed).
        created_at: Creation time; defaults to ``datetime.utcnow()``.
        started_at: Time execution started.
        completed_at: Completion time.

    Example:
        >>> task = Task(
        ...     id="task-123",
        ...     exp_name="my_voice",
        ...     config={"version": "v2", "batch_size": 4}
        ... )
        >>> task.status
        <TaskStatus.QUEUED: 'queued'>
    """
    id: str
    exp_name: str
    config: Dict[str, Any]
    job_id: Optional[str] = None
    status: TaskStatus = TaskStatus.QUEUED
    current_stage: Optional[str] = None
    progress: float = 0.0
    stage_progress: float = 0.0
    message: Optional[str] = None
    error_message: Optional[str] = None
    # NOTE: datetime.utcnow returns a naive datetime and is deprecated since
    # Python 3.12. It is kept so that serialized timestamps stay in the same
    # offset-less ISO format as before.
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the task to a plain dictionary.

        Returns:
            A dict keyed by field name. ``status`` is emitted as the enum's
            string value and the three timestamps as ISO-8601 strings (or
            ``None`` when unset). ``config`` is the same dict object held by
            the task, not a copy.
        """
        def iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            "id": self.id,
            "job_id": self.job_id,
            "exp_name": self.exp_name,
            "status": self.status.value,
            "config": self.config,
            "current_stage": self.current_stage,
            "progress": self.progress,
            "stage_progress": self.stage_progress,
            "message": self.message,
            "error_message": self.error_message,
            "created_at": iso(self.created_at),
            "started_at": iso(self.started_at),
            "completed_at": iso(self.completed_at),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Task":
        """
        Build a task from a dictionary such as the one made by ``to_dict``.

        Keys that are missing or explicitly ``None`` fall back to the field
        defaults for ``status``, ``config``, ``progress``, ``stage_progress``
        and ``created_at``, so the result always satisfies the declared
        (non-optional) field types.

        Args:
            data: Mapping with at least ``"id"`` and ``"exp_name"``.
                ``status`` may be a string value or a ``TaskStatus``;
                timestamps may be ISO-8601 strings or ``datetime`` objects.

        Returns:
            The reconstructed ``Task``.

        Raises:
            KeyError: If ``"id"`` or ``"exp_name"`` is missing.
            ValueError: If ``status`` is a string that is not a valid
                ``TaskStatus`` value, or a timestamp string is not valid
                ISO format.
        """
        def pick(key: str, default: Any) -> Any:
            # Treat an explicit None the same as a missing key.
            value = data.get(key)
            return default if value is None else value

        def parse_datetime(value: Any) -> Optional[datetime]:
            if value is None or isinstance(value, datetime):
                return value
            return datetime.fromisoformat(value)

        # Accept either the raw string value or an already-built enum member.
        status = pick("status", TaskStatus.QUEUED)
        if isinstance(status, str):
            status = TaskStatus(status)

        kwargs: Dict[str, Any] = {
            "id": data["id"],
            "job_id": data.get("job_id"),
            "exp_name": data["exp_name"],
            "status": status,
            "config": pick("config", {}),
            "current_stage": data.get("current_stage"),
            "progress": pick("progress", 0.0),
            "stage_progress": pick("stage_progress", 0.0),
            "message": data.get("message"),
            "error_message": data.get("error_message"),
            "started_at": parse_datetime(data.get("started_at")),
            "completed_at": parse_datetime(data.get("completed_at")),
        }

        # Only override created_at when a value was supplied; otherwise let
        # the dataclass default_factory fill it in instead of storing None.
        created_at = parse_datetime(data.get("created_at"))
        if created_at is not None:
            kwargs["created_at"] = created_at

        return cls(**kwargs)


@dataclass
class ProgressInfo:
    """
    Progress information payload.

    Used to pass progress updates between the subprocess and the main process.

    Attributes:
        type: Message type ("progress", "log", "error", "heartbeat").
        stage: Name of the current stage.
        stage_index: Index of the current stage.
        total_stages: Total number of stages.
        progress: Progress within the stage (0.0-1.0).
        overall_progress: Overall progress (0.0-1.0).
        message: Progress message.
        status: Status.
        data: Additional data.
    """
    type: str = "progress"
    stage: Optional[str] = None
    stage_index: Optional[int] = None
    total_stages: Optional[int] = None
    progress: float = 0.0
    overall_progress: float = 0.0
    message: Optional[str] = None
    status: Optional[str] = None
    data: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a plain dictionary.

        Returns:
            A dict keyed by field name. ``data`` is the same dict object
            held by the instance, not a copy.
        """
        return {
            "type": self.type,
            "stage": self.stage,
            "stage_index": self.stage_index,
            "total_stages": self.total_stages,
            "progress": self.progress,
            "overall_progress": self.overall_progress,
            "message": self.message,
            "status": self.status,
            "data": self.data,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProgressInfo":
        """
        Build an instance from a dictionary such as the one made by ``to_dict``.

        Keys that are missing or explicitly ``None`` fall back to the field
        defaults for ``type``, ``progress``, ``overall_progress`` and
        ``data``, so those non-optional fields never end up holding ``None``.

        Args:
            data: Mapping of field names to values; every key is optional.

        Returns:
            The reconstructed ``ProgressInfo``.
        """
        def pick(key: str, default: Any) -> Any:
            # Treat an explicit None the same as a missing key.
            value = data.get(key)
            return default if value is None else value

        return cls(
            type=pick("type", "progress"),
            stage=data.get("stage"),
            stage_index=data.get("stage_index"),
            total_stages=data.get("total_stages"),
            progress=pick("progress", 0.0),
            overall_progress=pick("overall_progress", 0.0),
            message=data.get("message"),
            status=data.get("status"),
            data=pick("data", {}),
        )