Update scripts/model_v4.py
scripts/model_v4.py  CHANGED  (+29 −12)

Most of the paired −/+ lines below appear to differ only in whitespace; the substantive change is the dtype-normalization block added to TinyFluxDeep.forward() in the last hunk.
@@ -41,12 +41,12 @@ from pathlib import Path
 class TinyFluxConfig:
     """
     Configuration for TinyFlux-Deep v4.1 model.
-
+
     This config fully defines the model architecture and can be used to:
     1. Initialize a new model
-    2. Convert checkpoints between versions
+    2. Convert checkpoints between versions
     3. Validate checkpoint compatibility
-
+
     All dimension constraints are validated on creation.
     """
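For orientation, a minimal construction sketch covering use 1 from the docstring, assuming keyword construction. Every value below is a hypothetical placeholder (only the field names appear elsewhere in this file), chosen to satisfy the two constraints validated on creation:

# Hypothetical values only: they satisfy hidden_size == num_attention_heads *
# attention_head_dim and sum(axes_dims_rope) == attention_head_dim, both of
# which are enforced by this config class.
config = TinyFluxConfig(
    hidden_size=1536,
    num_attention_heads=12,
    attention_head_dim=128,
    num_double_layers=4,
    num_single_layers=8,
    axes_dims_rope=(16, 56, 56),
    use_sol_prior=True,
    use_t5_vec=True,
    use_lune_expert=True,
)
model = TinyFluxDeep(config)  # assumed constructor; initializes a new model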
@@ -105,10 +105,10 @@ class TinyFluxConfig:
                 f"num_attention_heads * attention_head_dim ({expected_hidden})"
             )

-        # Validate RoPE dimensions
+        # Validate RoPE dimensions
         if isinstance(self.axes_dims_rope, list):
             self.axes_dims_rope = tuple(self.axes_dims_rope)
-
+
         rope_sum = sum(self.axes_dims_rope)
         if rope_sum != self.attention_head_dim:
             raise ValueError(
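This check exists because rotary embeddings split each head's channels across the positional axes, so the per-axis dims must sum to exactly one head's width. With illustrative numbers (not the model's real defaults):

# Illustrative numbers, not the model's actual defaults.
head_dim = 128
axes_ok  = (16, 56, 56)   # sum == 128: accepted (a list would be coerced to a tuple)
axes_bad = (16, 56, 48)   # sum == 120: ValueError at config creation
assert sum(axes_ok) == head_dim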
@@ -158,11 +158,11 @@ class TinyFluxConfig:
     def validate_checkpoint(self, state_dict: Dict[str, torch.Tensor]) -> List[str]:
         """
         Validate that a checkpoint matches this config.
-
+
         Returns list of warnings (empty if perfect match).
         """
         warnings = []
-
+
         # Check double block count
         max_double = 0
         for key in state_dict:
@@ -171,7 +171,7 @@ class TinyFluxConfig:
                 max_double = max(max_double, idx + 1)
         if max_double != self.num_double_layers:
             warnings.append(f"double_blocks: checkpoint has {max_double}, config expects {self.num_double_layers}")
-
+
         # Check single block count
         max_single = 0
         for key in state_dict:
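The lines that actually parse idx out of each key (169-170 and 178-179) fall outside these hunks, so they stay elided here. The sketch below is a hypothetical reconstruction of the scanning pattern used by both loops, assuming keys follow the double_blocks.<idx>.<...> layout common to Flux-style checkpoints:

# Hypothetical reconstruction: only the max() update is visible in the diff.
# Keys are assumed to look like "double_blocks.7.attn.qkv.weight".
max_double = 0
for key in state_dict:
    if key.startswith("double_blocks."):
        idx = int(key.split(".")[1])
        max_double = max(max_double, idx + 1)  # highest index + 1 == block count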
@@ -180,25 +180,25 @@ class TinyFluxConfig:
                 max_single = max(max_single, idx + 1)
         if max_single != self.num_single_layers:
             warnings.append(f"single_blocks: checkpoint has {max_single}, config expects {self.num_single_layers}")
-
+
         # Check hidden size from a known weight
         if "img_in.weight" in state_dict:
             w = state_dict["img_in.weight"]
             if w.shape[0] != self.hidden_size:
                 warnings.append(f"hidden_size: checkpoint has {w.shape[0]}, config expects {self.hidden_size}")
-
+
         # Check for v4.1 components
         has_sol = any(k.startswith("sol_prior.") for k in state_dict)
         has_t5 = any(k.startswith("t5_pool.") for k in state_dict)
         has_lune = any(k.startswith("lune_predictor.") for k in state_dict)
-
+
         if self.use_sol_prior and not has_sol:
             warnings.append("config expects sol_prior but checkpoint missing it")
         if self.use_t5_vec and not has_t5:
             warnings.append("config expects t5_pool but checkpoint missing it")
         if self.use_lune_expert and not has_lune:
             warnings.append("config expects lune_predictor but checkpoint missing it")
-
+
         return warnings

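A minimal sketch of the intended validation flow; the safetensors loader and the checkpoint path are assumptions for illustration, not part of this file:

from safetensors.torch import load_file

state_dict = load_file("checkpoints/tinyflux_v4.safetensors")  # illustrative path
for warning in config.validate_checkpoint(state_dict):
    print(f"[checkpoint mismatch] {warning}")
# No output means the checkpoint matches the config exactly.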
@@ -1112,6 +1112,23 @@ class TinyFluxDeep(nn.Module):
         if expert_features is not None and lune_features is None:
             lune_features = expert_features

+        # Ensure consistent dtype (text encoders often output float32)
+        model_dtype = self.img_in.weight.dtype
+        hidden_states = hidden_states.to(dtype=model_dtype)
+        encoder_hidden_states = encoder_hidden_states.to(dtype=model_dtype)
+        pooled_projections = pooled_projections.to(dtype=model_dtype)
+        timestep = timestep.to(dtype=model_dtype)
+
+        # Cast optional expert inputs if provided
+        if lune_features is not None:
+            lune_features = lune_features.to(dtype=model_dtype)
+        if sol_stats is not None:
+            sol_stats = sol_stats.to(dtype=model_dtype)
+        if sol_spatial is not None:
+            sol_spatial = sol_spatial.to(dtype=model_dtype)
+        if guidance is not None:
+            guidance = guidance.to(dtype=model_dtype)
+
         # Input projections
         img = self.img_in(hidden_states)
         txt = self.txt_in(encoder_hidden_states)
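The new block guards against the usual mixed-precision failure: the transformer's weights sit in bfloat16/float16 while text encoders return float32, and the first matmul then fails. A standalone repro of the failure mode these casts prevent (layer sizes arbitrary):

import torch

proj = torch.nn.Linear(64, 64, dtype=torch.bfloat16)  # stands in for img_in
x = torch.randn(1, 64)                                 # float32, as text encoders typically return
# proj(x) would raise a RuntimeError (mat1 and mat2 must have the same dtype)
y = proj(x.to(dtype=proj.weight.dtype))                # cast first, as forward() now does

Casting once at the top of forward() keeps the fix in one place instead of requiring every caller to remember the conversion.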