AbstractPhil committed
Commit bbbffae · verified · 1 Parent(s): 7527747

Update scripts/model_v4.py

Files changed (1):
  1. scripts/model_v4.py +29 -12
scripts/model_v4.py CHANGED
@@ -41,12 +41,12 @@ from pathlib import Path
 class TinyFluxConfig:
     """
     Configuration for TinyFlux-Deep v4.1 model.
-
+
     This config fully defines the model architecture and can be used to:
     1. Initialize a new model
-    2. Convert checkpoints between versions
+    2. Convert checkpoints between versions
     3. Validate checkpoint compatibility
-
+
     All dimension constraints are validated on creation.
     """
 
@@ -105,10 +105,10 @@ class TinyFluxConfig:
             f"num_attention_heads * attention_head_dim ({expected_hidden})"
         )
 
-        # Validate RoPE dimensions
+        # Validate RoPE dimensions
        if isinstance(self.axes_dims_rope, list):
             self.axes_dims_rope = tuple(self.axes_dims_rope)
-
+
         rope_sum = sum(self.axes_dims_rope)
         if rope_sum != self.attention_head_dim:
             raise ValueError(
@@ -158,11 +158,11 @@ class TinyFluxConfig:
     def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
         """
         Validate that a checkpoint matches this config.
-
+
         Returns list of warnings (empty if perfect match).
         """
         warnings = []
-
+
         # Check double block count
         max_double = 0
         for key in state_dict:
@@ -171,7 +171,7 @@ class TinyFluxConfig:
             max_double = max(max_double, idx + 1)
         if max_double != self.num_double_layers:
             warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
-
+
         # Check single block count
         max_single = 0
         for key in state_dict:
@@ -180,25 +180,25 @@ class TinyFluxConfig:
             max_single = max(max_single, idx + 1)
         if max_single != self.num_single_layers:
             warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
-
+
         # Check hidden size from a known weight
         if "img_in.weight" in state_dict:
             w = state_dict["img_in.weight"]
             if w.shape[0] != self.hidden_size:
                 warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
-
+
         # Check for v4.1 components
         has_sol = any(k.startswith("sol_prior.") for k in state_dict)
         has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
         has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
-
+
         if self.use_sol_prior and not has_sol:
             warnings.append("config expects sol_prior but checkpoint missing it")
         if self.use_t5_vec and not has_t5:
             warnings.append("config expects t5_pool but checkpoint missing it")
         if self.use_lune_expert and not has_lune:
             warnings.append("config expects lune_predictor but checkpoint missing it")
-
+
         return warnings
 
 
@@ -1112,6 +1112,23 @@ class TinyFluxDeep(nn.Module):
         if expert_features is not None and lune_features is None:
             lune_features = expert_features
 
+        # Ensure consistent dtype (text encoders often output float32)
+        model_dtype = self.img_in.weight.dtype
+        hidden_states = hidden_states.to(dtype=model_dtype)
+        encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
+        pooled_projections = pooled_projections.to(dtype=model_dtype)
+        timestep = timestep.to(dtype=model_dtype)
+
+        # Cast optional expert inputs if provided
+        if lune_features is not None:
+            lune_features = lune_features.to(dtype=model_dtype)
+        if sol_stats is not None:
+            sol_stats = sol_stats.to(dtype=model_dtype)
+        if sol_spatial is not None:
+            sol_spatial = sol_spatial.to(dtype=model_dtype)
+        if guidance is not None:
+            guidance = guidance.to(dtype=model_dtype)
+
         # Input projections
         img = self.img_in(hidden_states)
         txt = self.txt_in(encoder_hidden_states)
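
Note on the dimension checks visible in the config hunks above: hidden_size must equal num_attention_heads * attention_head_dim, and the per-axis RoPE dims must sum to attention_head_dim. A minimal sketch with made-up numbers (not the repo's actual defaults):

# Illustrative values only; the repo's real defaults may differ.
num_attention_heads = 8
attention_head_dim = 64
hidden_size = 512
axes_dims_rope = (16, 24, 24)

# hidden_size must equal num_attention_heads * attention_head_dim
assert hidden_size == num_attention_heads * attention_head_dim

# the RoPE axis dims must sum to attention_head_dim (16 + 24 + 24 == 64)
assert sum(axes_dims_rope) == attention_head_dim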
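
For reference, a hedged usage sketch of the validate_checkpoint method shown in the diff. The checkpoint path is hypothetical, and constructing TinyFluxConfig with defaults is an assumption:

# Hypothetical usage of TinyFluxConfig.validate_checkpoint; the path and the
# no-argument constructor are illustrative, not taken from the repo.
import torch

from scripts.model_v4 import TinyFluxConfig

config = TinyFluxConfig()  # assumes defaults match the target checkpoint
state_dict = torch.load("checkpoints/tinyflux_v41.pt", map_location="cpu")

# An empty list means a perfect match; each entry names one mismatch.
for warning in config.validate_checkpoint(state_dict):
    print(f"[checkpoint] {warning}")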
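
The substantive change in this commit is the dtype guard in TinyFluxDeep.forward. A self-contained sketch (not repo code) of the failure mode it prevents, using a bare nn.Linear as a stand-in for img_in:

# Half-precision weights fed float32 activations raise a dtype mismatch;
# casting inputs to the weight dtype, as the commit does, avoids it.
import torch
import torch.nn as nn

proj = nn.Linear(64, 128).half()   # model loaded in float16
feats = torch.randn(2, 64)         # text encoders often emit float32

try:
    proj(feats)                    # mat1/mat2 dtype mismatch
except RuntimeError as err:
    print(f"without the cast: {err}")

# The commit's approach: cast every input to the model's weight dtype first.
out = proj(feats.to(proj.weight.dtype))
print(out.dtype)  # torch.float16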