from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor
from PIL import Image
import torch


def merge_lora_weight(tensors_A, tensors_B):
    # Merge multiple LoRAs by concatenating along the rank dimension:
    # each A_i has shape (rank, in_dim) and each B_i has shape (out_dim, rank),
    # so the merged pair has rank len(tensors_A) * rank and unchanged outer dims.
    lora_A = torch.concat(tensors_A, dim=0)
    lora_B = torch.concat(tensors_B, dim=1)
    return lora_A, lora_B


def merge_lora(loras, alpha=1):
    # Merge a list of per-image LoRA state dicts into one. The scaling factor
    # alpha is folded into lora_A, so the merged weights need no extra scaling.
    lora_merged = {}
    keys = [key for key in loras[0].keys() if ".lora_A." in key]
    for key in keys:
        tensors_A = [lora[key] for lora in loras]
        tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
        lora_merged[key] = lora_A * alpha
        lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
    return lora_merged


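def _check_merge_equivalence():
    # Illustrative sanity check (not used by the template): rank concatenation
    # is equivalent to summing the individual LoRA deltas, i.e.
    # concat(B0, B1, dim=1) @ concat(A0, A1, dim=0) == B0 @ A0 + B1 @ A1.
    tensors_A = [torch.randn(4, 8) for _ in range(2)]
    tensors_B = [torch.randn(16, 4) for _ in range(2)]
    lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
    delta = lora_B @ lora_A
    delta_ref = tensors_B[0] @ tensors_A[0] + tensors_B[1] @ tensors_A[1]
    assert torch.allclose(delta, delta_ref, atol=1e-5)

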
class Siglip2ImageEncoder(SiglipVisionTransformer):
    def __init__(self):
        # Vision tower configuration matching the SigLIP2-G384 checkpoint.
        config = SiglipVisionConfig(
            attention_dropout=0.0,
            hidden_act="gelu_pytorch_tanh",
            hidden_size=1536,
            image_size=384,
            intermediate_size=6144,
            layer_norm_eps=1e-06,
            num_attention_heads=16,
            num_channels=3,
            num_hidden_layers=40,
            patch_size=16,
            _attn_implementation="sdpa",
        )
        super().__init__(config)
        # Preprocessing matching the SigLIP2-G384 checkpoint: resize to 384x384
        # with bilinear resampling (resample=2), rescale by 1/255, then
        # normalize to [-1, 1].
        self.processor = SiglipImageProcessor(
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
            resample=2,
            rescale_factor=1 / 255,
            size={"height": 384, "width": 384},
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda", query_embs=None):
        # Preprocess a single PIL image and run the SigLIP vision tower.
        pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=False)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=False,
            output_hidden_states=False,
        )
        last_hidden_state = self.post_layernorm(encoder_outputs.last_hidden_state)

        if query_embs is None:
            # Standard SigLIP attention pooling with the checkpoint's single probe.
            pooler_output = self.head(last_hidden_state)
        else:
            # Re-run the attention-pooling head with external query embeddings,
            # yielding one pooled vector per query instead of a single probe.
            hidden_state = self.head.attention(query_embs, last_hidden_state, last_hidden_state)[0]
            residual = hidden_state
            hidden_state = self.head.layernorm(hidden_state)
            pooler_output = residual + self.head.mlp(hidden_state)
        return pooler_output


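def _encoder_usage_sketch():
    # Minimal usage sketch (illustrative only; assumes a CUDA device and that
    # SigLIP2-G384 weights have already been loaded into the encoder).
    encoder = Siglip2ImageEncoder().to(dtype=torch.bfloat16, device="cuda")
    image = Image.new("RGB", (512, 512))
    # One query per desired pooled vector, e.g. one per target LoRA layer.
    queries = torch.randn((1, 40, 1536), dtype=torch.bfloat16, device="cuda")
    pooled = encoder(image, query_embs=queries)  # shape (1, 40, 1536)
    return pooled

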
class CompressedMLP(torch.nn.Module):
    """Two stacked linear layers with no nonlinearity: a low-rank factorization
    of a single Linear(in_dim, out_dim) that keeps parameter counts small."""
    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.proj_out(x)
        return x


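def _compressed_mlp_savings(in_dim=1536, mid_dim=256, out_dim=12288):
    # Illustrative arithmetic (defaults correspond to the proj_a head of the
    # first LoRA pattern below, where out_dim = dim_in * rank = 3072 * 4):
    # factorized: 1536*256 + 256*12288 ~= 3.5M parameters, versus a dense
    # Linear(in_dim, out_dim): 1536*12288 ~= 18.9M parameters.
    factorized = in_dim * mid_dim + mid_dim * out_dim
    dense = in_dim * out_dim
    return factorized, dense

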
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    """Projects a single image embedding into one (lora_A, lora_B) weight pair."""
    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x):
        # x is expected to hold a single embedding (batch size 1); the flat
        # projections are reshaped into LoRA matrices of shape
        # (rank, lora_a_dim) and (lora_b_dim, rank).
        lora_a = self.proj_a(x).view(self.rank, self.lora_a_dim)
        lora_b = self.proj_b(x).view(self.lora_b_dim, self.rank)
        return lora_a, lora_b


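def _lora_matrix_shapes_sketch():
    # Shape check (illustrative, not used by the template): a 1536-dim
    # embedding maps to a rank-4 LoRA pair for a 3072 -> 27648 target layer.
    layer = ImageEmbeddingToLoraMatrix(1536, 256, 3072, 27648, rank=4)
    lora_a, lora_b = layer(torch.randn(1, 1, 1536))
    assert lora_a.shape == (4, 3072) and lora_b.shape == (27648, 4)

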
class FLUX2Image2LoRAQuerys(torch.nn.Module):
    """Learnable query embeddings, one per target LoRA layer."""
    def __init__(self, length, dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn((1, length, dim)))

    def forward(self):
        return self.weights


class FLUX2Image2LoRAModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Target layers in the FLUX.2 DiT: one LoRA per attention projection in
        # each of the 20 single transformer blocks (40 LoRA pairs in total).
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
            },
        ]
        self.image_encoder = Siglip2ImageEncoder()
        self.parse_lora_layers(
            self.lora_patterns,
            dim_image=1536,
            compress_dim=256,
            rank=4,
        )
        # One learnable query per target layer, fed to the encoder's pooling head.
        self.query_embs = FLUX2Image2LoRAQuerys(len(self.layers), 1536)

    def parse_lora_layers(self, lora_patterns, dim_image, compress_dim, rank):
        # Instantiate one projection module per (pattern, block) pair.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = ImageEmbeddingToLoraMatrix(
                    dim_image, compress_dim, lora_pattern["dim_in"], lora_pattern["dim_out"], rank
                )
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    @torch.no_grad()
    def process_inputs(self, image, scale=1, **kwargs):
        return {"image": image, "scale": scale}

    def forward_single_image(self, image):
        # Encode the image once: the pooled output contains one 1536-dim
        # embedding per target layer, which is split and projected into
        # per-layer LoRA matrices.
        embs = self.image_encoder(image, query_embs=self.query_embs.weights, device=self.query_embs.weights.device)
        embs = embs.chunk(len(self.layers), dim=1)
        lora = {}
        for emb, name, layer in zip(embs, self.names, self.layers):
            lora_a, lora_b = layer(emb)
            lora[f"{name}.lora_A.default.weight"] = lora_a
            lora[f"{name}.lora_B.default.weight"] = lora_b
        return {"lora": lora}

    def forward(self, image, scale=1, **kwargs):
        # Predict one LoRA per reference image, then merge them by rank
        # concatenation, averaging across images and applying the global scale.
        if not isinstance(image, list):
            image = [image]
        loras = [self.forward_single_image(i)["lora"] for i in image]
        lora = merge_lora(loras, alpha=scale / len(loras))
        return {"lora": lora}


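def _image_to_lora_sketch():
    # End-to-end usage sketch (illustrative only; assumes initialized weights,
    # a CUDA device, and hypothetical reference images ref_a.jpg / ref_b.jpg).
    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    images = [Image.open("ref_a.jpg"), Image.open("ref_b.jpg")]
    lora = model(images, scale=1)["lora"]
    # Keys follow the PEFT naming convention, e.g.
    # "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.default.weight";
    # with two images merged, each lora_A has shape (2 * rank, dim_in) = (8, dim_in).
    return lora

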
class DataAnnotator:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        # Resize images so both sides are divisible by 16 and the total area
        # stays within 1024x1024 pixels.
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        return {"image": image}


def initialize_model_weights():
    from diffsynth import ModelConfig, load_state_dict
    from safetensors.torch import save_file
    import os
    # Download the SigLIP2-G384 checkpoint and load it into the image encoder.
    config = ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
    config.download_if_necessary()
    state_dict = load_state_dict(config.path, torch_dtype=torch.bfloat16, device="cuda")
    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    model.image_encoder.load_state_dict(state_dict)
    # Initialize every layer query from the checkpoint's pooling probe.
    query_embs = {"weights": torch.concat([state_dict["head.probe"]] * len(model.layers), dim=1)}
    model.query_embs.load_state_dict(query_embs, strict=False)
    # Zero-init the final lora_B projection so the predicted LoRA starts as a
    # no-op (lora_B @ lora_A == 0), and damp the remaining projection weights.
    lora_weights = {}
    for name, param in model.named_parameters():
        if ".proj_b.proj_out." in name:
            lora_weights[name] = param * 0
        elif ".proj_b." in name or ".proj_a." in name:
            lora_weights[name] = param * 0.3
    model.load_state_dict(lora_weights, strict=False)
    # Report the total parameter count, then save next to this file.
    print(sum(p.numel() for p in model.parameters()))
    save_file(model.state_dict(), os.path.join(os.path.dirname(__file__), "model.safetensors"))


# Module-level entry points picked up by the template loader.
TEMPLATE_MODEL = FLUX2Image2LoRAModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator


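if __name__ == "__main__":
    # Illustrative entry point (an assumption, not part of the original
    # template): run once to materialize model.safetensors next to this file.
    initialize_model_weights()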