from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor
from PIL import Image
import torch


def merge_lora_weight(tensors_A, tensors_B):
    # Stack the per-image A matrices along the rank axis and the B matrices
    # along the matching input axis, forming a single higher-rank LoRA pair.
    lora_A = torch.concat(tensors_A, dim=0)
    lora_B = torch.concat(tensors_B, dim=1)
    return lora_A, lora_B


def merge_lora(loras, alpha=1):
    # Merge several LoRA state dicts into one. Since B_cat @ A_cat equals the
    # sum of the individual B_i @ A_i updates, scaling only A by alpha scales
    # the merged update as a whole.
    lora_merged = {}
    keys = [i for i in loras[0].keys() if ".lora_A." in i]
    for key in keys:
        tensors_A = [lora[key] for lora in loras]
        tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
        lora_merged[key] = lora_A * alpha
        lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
    return lora_merged
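
# A minimal sanity check (illustration only, not part of the template): by
# block-matrix multiplication, B_cat @ A_cat = sum_i B_i @ A_i, so merging n
# rank-r LoRAs yields one rank-(n*r) LoRA whose update is the alpha-scaled sum
# of the individual updates. The key name "layer" below is hypothetical.
def _check_merge_lora():
    rank, dim_in, dim_out, alpha = 4, 6, 10, 0.5
    loras = [
        {
            "layer.lora_A.default.weight": torch.randn(rank, dim_in),
            "layer.lora_B.default.weight": torch.randn(dim_out, rank),
        }
        for _ in range(2)
    ]
    merged = merge_lora(loras, alpha=alpha)
    delta = merged["layer.lora_B.default.weight"] @ merged["layer.lora_A.default.weight"]
    expected = alpha * sum(
        lora["layer.lora_B.default.weight"] @ lora["layer.lora_A.default.weight"]
        for lora in loras
    )
    assert torch.allclose(delta, expected, atol=1e-5)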

class Siglip2ImageEncoder(SiglipVisionTransformer):
    # SigLIP2-G/384 vision tower with a hard-coded config, plus an attached
    # image processor so the encoder can consume PIL images directly.
    def __init__(self):
        config = SiglipVisionConfig(
            attention_dropout=0.0,
            dtype="float32",
            hidden_act="gelu_pytorch_tanh",
            hidden_size=1536,
            image_size=384,
            intermediate_size=6144,
            layer_norm_eps=1e-06,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_channels=3,
            num_hidden_layers=40,
            patch_size=16,
            transformers_version="4.56.1",
            _attn_implementation="sdpa",
        )
        # For compatibility with transformers
        import sys
        sys.modules["template_model"] = None
        super().__init__(config)
        self.processor = SiglipImageProcessor(
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_processor_type="SiglipImageProcessor",
            image_std=[0.5, 0.5, 0.5],
            processor_class="SiglipProcessor",
            resample=2,
            rescale_factor=0.00392156862745098,
            size={"height": 384, "width": 384},
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda", query_embs=None):
        pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
        output_attentions = False
        output_hidden_states = False
        interpolate_pos_encoding = False
        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)
        if query_embs is None:
            pooler_output = self.head(last_hidden_state)
        else:
            # Re-run the attention-pooling head with external query embeddings
            # instead of its learned probe, yielding one pooled token per query.
            hidden_state = self.head.attention(query_embs, last_hidden_state, last_hidden_state)[0]
            residual = hidden_state
            hidden_state = self.head.layernorm(hidden_state)
            pooler_output = residual + self.head.mlp(hidden_state)
        return pooler_output


class CompressedMLP(torch.nn.Module):
    # Two stacked linear layers with a low-dimensional bottleneck, keeping the
    # image-embedding-to-LoRA projections small.
    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.proj_out(x)
        return x


class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    # Projects one pooled image embedding into a (lora_A, lora_B) pair for a
    # single target layer.
    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x):
        lora_a = self.proj_a(x).view(self.rank, self.lora_a_dim)
        lora_b = self.proj_b(x).view(self.lora_b_dim, self.rank)
        return lora_a, lora_b


class FLUX2Image2LoRAQueries(torch.nn.Module):
    # Learned query embeddings, one per target LoRA layer, fed to the SigLIP
    # attention-pooling head.
    def __init__(self, length, dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn((1, length, dim)))

    def forward(self):
        return self.weights


class FLUX2Image2LoRAModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
            },
        ]
        self.image_encoder = Siglip2ImageEncoder()
        self.parse_lora_layers(
            self.lora_patterns,
            dim_image=1536,
            compress_dim=256,
            rank=4,
        )
        self.query_embs = FLUX2Image2LoRAQueries(len(self.layers), 1536)

    def parse_lora_layers(self, lora_patterns, dim_image, compress_dim, rank):
        # Expand each pattern into one projection module per transformer block.
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = ImageEmbeddingToLoraMatrix(dim_image, compress_dim, lora_pattern["dim_in"], lora_pattern["dim_out"], rank)
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)

    @torch.no_grad()
    def process_inputs(self, image, scale=1, **kwargs):
        return {"image": image, "scale": scale}

    def forward_single_image(self, image):
        embs = self.image_encoder(image, query_embs=self.query_embs.weights, device=self.query_embs.weights.device)
        # One pooled embedding per target layer; each becomes one LoRA pair.
        embs = embs.chunk(len(self.layers), dim=1)
        lora = {}
        for emb, name, layer in zip(embs, self.names, self.layers):
            lora_a, lora_b = layer(emb)
            lora[f"{name}.lora_A.default.weight"] = lora_a
            lora[f"{name}.lora_B.default.weight"] = lora_b
        return {"lora": lora}

    def forward(self, image, scale=1, **kwargs):
        if not isinstance(image, list):
            image = [image]
        loras = [self.forward_single_image(i)["lora"] for i in image]
        # Average the per-image LoRAs, then apply the user-specified scale.
        lora = merge_lora(loras, alpha=1 / len(loras) * scale)
        return {"lora": lora}


class DataAnnotator:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",  # If your dataset contains relative paths, please specify the root path here.
            max_pixels=1024*1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        return {"image": image}


def initialize_model_weights():
    from diffsynth import ModelConfig, load_state_dict
    from safetensors.torch import save_file
    import os
    config = ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
    config.download_if_necessary()
    state_dict = load_state_dict(config.path, torch_dtype=torch.bfloat16, device="cuda")
    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    model.image_encoder.load_state_dict(state_dict)
    # Initialize every query from the pretrained attention-pooling probe.
    query_embs = {"weights": torch.concat([state_dict["head.probe"]] * len(model.layers), dim=1)}
    model.query_embs.load_state_dict(query_embs, strict=False)
    lora_weights = {}
    for name, param in model.named_parameters():
        if ".proj_b.proj_out." in name:
            # Zero the final B projection so the generated LoRAs start as no-ops.
            lora_weights[name] = param * 0
        elif ".proj_b." in name or ".proj_a." in name:
            lora_weights[name] = param * 0.3
    model.load_state_dict(lora_weights, strict=False)
    # Report the total parameter count.
    print(sum(p.numel() for p in model.parameters()))
    save_file(model.state_dict(), os.path.join(os.path.dirname(__file__), "model.safetensors"))


TEMPLATE_MODEL = FLUX2Image2LoRAModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator
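

# Usage sketch (assumptions: a CUDA device is available, initialize_model_weights()
# has already written "model.safetensors" into the working directory, and
# "example.png" is a hypothetical input image). Produces a LoRA state dict keyed
# by the single_transformer_blocks layer names listed in lora_patterns.
if __name__ == "__main__":
    from safetensors.torch import load_file

    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    model.load_state_dict(load_file(TEMPLATE_MODEL_PATH))
    image = Image.open("example.png").convert("RGB")
    with torch.no_grad():
        lora = model(image, scale=1)["lora"]
    for key in list(lora)[:4]:
        print(key, tuple(lora[key].shape))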