File size: 9,446 Bytes
d8af39a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
diff --git a/fastvideo/dataset/__init__.py b/fastvideo/dataset/__init__.py
index b82c653..f76077b 100644
--- a/fastvideo/dataset/__init__.py
+++ b/fastvideo/dataset/__init__.py
@@ -4,6 +4,8 @@ from torchvision.transforms import Lambda
 
 from fastvideo.dataset.parquet_dataset_map_style import (
     build_parquet_map_style_dataloader)
+from fastvideo.dataset.pants_latent_dataset import (
+    build_pants_latent_dataloader, is_pants_latent_path)
 from fastvideo.dataset.ltx2_precomputed_dataset import (
     build_ltx2_precomputed_dataloader, LTX2PrecomputedDataset)
 from fastvideo.dataset.preprocessing_datasets import VideoCaptionMergedDataset, TextDataset
@@ -46,6 +48,8 @@ def gettextdataset(args) -> TextDataset:
 
 __all__ = [
     "build_parquet_map_style_dataloader",
+    "build_pants_latent_dataloader",
+    "is_pants_latent_path",
     "build_ltx2_precomputed_dataloader",
     "LTX2PrecomputedDataset",
     "ValidationDataset",
diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py
index 575d6dc..140bb31 100644
--- a/fastvideo/training/training_pipeline.py
+++ b/fastvideo/training/training_pipeline.py
@@ -28,7 +28,11 @@ try:
 except Exception:
     pass
 from fastvideo.api.sampling_param import SamplingParam
-from fastvideo.dataset import build_parquet_map_style_dataloader
+from fastvideo.dataset import (
+    build_pants_latent_dataloader,
+    build_parquet_map_style_dataloader,
+    is_pants_latent_path,
+)
 from fastvideo.dataset.dataloader.schema import pyarrow_schema_t2v
 from fastvideo.dataset.validation_dataset import ValidationDataset
 from fastvideo.distributed import (cleanup_dist_env_and_memory, get_local_torch_device, get_sp_group, get_world_group)
@@ -42,7 +46,8 @@ from fastvideo.training.activation_checkpoint import (apply_activation_checkpoin
 from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers)
 from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases,
                                                compute_density_for_timestep_sampling, count_trainable, get_scheduler,
-                                               get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint)
+                                               get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint,
+                                               EMA_FSDP, gather_state_dict_on_cpu_rank0, custom_to_hf_state_dict)
 from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict)
 
 try:
@@ -82,6 +87,7 @@ class TrainingPipeline(LoRAPipeline, ABC):
         super().__init__(model_path, fastvideo_args, required_config_modules, loaded_modules)  # type: ignore
         self.tracker = DummyTracker()
         self.validation_ref_videos_logged = False
+        self.generator_ema: EMA_FSDP | None = None
 
     def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
         raise RuntimeError("create_pipeline_stages should not be called for training pipeline")
@@ -167,16 +173,27 @@ class TrainingPipeline(LoRAPipeline, ABC):
                 last_epoch=self.init_steps - 1,
             )
 
-        self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
-            training_args.data_path,
-            training_args.train_batch_size,
-            parquet_schema=self.train_dataset_schema,
-            num_data_workers=training_args.dataloader_num_workers,
-            cfg_rate=training_args.training_cfg_rate,
-            drop_last=True,
-            text_padding_length=training_args.pipeline_config.text_encoder_configs[0].arch_config.
-            text_len,  # type: ignore[attr-defined]
-            seed=self.seed)
+        text_padding_length = training_args.pipeline_config.text_encoder_configs[0].arch_config.text_len  # type: ignore[attr-defined]
+        if is_pants_latent_path(training_args.data_path):
+            self.train_dataset, self.train_dataloader = build_pants_latent_dataloader(
+                training_args.data_path,
+                training_args.train_batch_size,
+                num_data_workers=training_args.dataloader_num_workers,
+                cfg_rate=training_args.training_cfg_rate,
+                drop_last=True,
+                text_padding_length=text_padding_length,
+                seed=self.seed,
+            )
+        else:
+            self.train_dataset, self.train_dataloader = build_parquet_map_style_dataloader(
+                training_args.data_path,
+                training_args.train_batch_size,
+                parquet_schema=self.train_dataset_schema,
+                num_data_workers=training_args.dataloader_num_workers,
+                cfg_rate=training_args.training_cfg_rate,
+                drop_last=True,
+                text_padding_length=text_padding_length,
+                seed=self.seed)
 
         self.noise_scheduler = noise_scheduler
         if self.training_args.boundary_ratio is not None:
@@ -460,6 +477,43 @@ class TrainingPipeline(LoRAPipeline, ABC):
         training_batch.grad_norm = grad_norm
         return training_batch
 
+    def _maybe_init_ema(self, step: int) -> None:
+        if not self.training_args.use_ema:
+            return
+        if self.generator_ema is not None:
+            return
+        if step < self.training_args.ema_start_step:
+            return
+        if self.training_args.ema_decay <= 0:
+            return
+        self.generator_ema = EMA_FSDP(self.transformer, decay=self.training_args.ema_decay)
+        logger.info("Created generator EMA at step %s with decay=%s", step, self.training_args.ema_decay)
+
+    def _maybe_update_ema(self, step: int) -> None:
+        self._maybe_init_ema(step)
+        if self.generator_ema is not None:
+            self.generator_ema.update(self.transformer)
+
+    def _save_ema_weights(self, step: int) -> None:
+        if not self.training_args.use_ema or self.generator_ema is None:
+            return
+        ema_dir = os.path.join(self.training_args.output_dir, f"ema_checkpoint-{step}")
+        os.makedirs(ema_dir, exist_ok=True)
+        with self.generator_ema.apply_to_model(self.transformer):
+            cpu_state = gather_state_dict_on_cpu_rank0(self.transformer, device=None)
+        if self.global_rank == 0:
+            from safetensors.torch import save_file
+
+            diffusers_state_dict = custom_to_hf_state_dict(
+                cpu_state,
+                self.transformer.reverse_param_names_mapping,
+            )
+            save_file(
+                diffusers_state_dict,
+                os.path.join(ema_dir, "diffusion_pytorch_model.safetensors"),
+            )
+            logger.info("Saved EMA transformer weights to %s", ema_dir)
+
     @profile_region("profiler_region_training_train_one_step")
     def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch:
         training_batch = self._prepare_training(training_batch)
@@ -571,6 +625,7 @@ class TrainingPipeline(LoRAPipeline, ABC):
             training_batch.current_timestep = step
             training_batch.current_vsa_sparsity = current_vsa_sparsity
             training_batch = self.train_one_step(training_batch)
+            self._maybe_update_ema(step)
 
             loss = float(training_batch.total_loss)
             grad_norm = training_batch.grad_norm
@@ -594,6 +649,9 @@ class TrainingPipeline(LoRAPipeline, ABC):
                     "grad_norm": grad_norm,
                     "vsa_sparsity": current_vsa_sparsity,
                 }
+                if self.training_args.use_ema:
+                    metrics["ema_enabled"] = self.generator_ema is not None
+                    metrics["ema_decay"] = self.training_args.ema_decay
                 try:
                     metrics["batch_size"] = int(training_batch.raw_latent_shape[0])
 
@@ -622,6 +680,7 @@ class TrainingPipeline(LoRAPipeline, ABC):
                     save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir, step,
                                     self.optimizer, self.train_dataloader, self.lr_scheduler,
                                     self.noise_random_generator)
+                    self._save_ema_weights(step)
                 self.transformer.train()
                 self.sp_group.barrier()
 
@@ -637,9 +696,13 @@ class TrainingPipeline(LoRAPipeline, ABC):
                                 trainable_params)
 
         self.tracker.finish()
-        save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
-                        self.training_args.max_train_steps, self.optimizer, self.train_dataloader, self.lr_scheduler,
-                        self.noise_random_generator)
+        if os.environ.get("FASTVIDEO_SKIP_FINAL_CHECKPOINT", "0") == "1":
+            logger.info("Skipping final checkpoint because FASTVIDEO_SKIP_FINAL_CHECKPOINT=1")
+        else:
+            save_checkpoint(self.transformer, self.global_rank, self.training_args.output_dir,
+                            self.training_args.max_train_steps, self.optimizer, self.train_dataloader,
+                            self.lr_scheduler, self.noise_random_generator)
+            self._save_ema_weights(self.training_args.max_train_steps)
 
         if envs.FASTVIDEO_TORCH_PROFILER_DIR:
             logger.info("Stopping profiler...")