from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer, SiglipVisionConfig
from transformers import SiglipImageProcessor
from PIL import Image
import torch
def merge_lora_weight(tensors_A, tensors_B):
    # Stack per-image LoRA factors along the rank dimension.
    lora_A = torch.concat(tensors_A, dim=0)
    lora_B = torch.concat(tensors_B, dim=1)
    return lora_A, lora_B
def merge_lora(loras, alpha=1):
    lora_merged = {}
    keys = [i for i in loras[0].keys() if ".lora_A." in i]
    for key in keys:
        tensors_A = [lora[key] for lora in loras]
        tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
        lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
        lora_merged[key] = lora_A * alpha
        lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
    return lora_merged
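# Illustrative sanity check (not executed here; hypothetical key names and shapes):
# concatenating lora_A along dim 0 and lora_B along dim 1 stacks the ranks, so the
# merged update B @ A equals the sum of the individual per-LoRA updates.
#
#   A1, A2 = torch.randn(4, 8), torch.randn(4, 8)   # rank 4, in_dim 8
#   B1, B2 = torch.randn(8, 4), torch.randn(8, 4)   # out_dim 8, rank 4
#   merged = merge_lora([
#       {"x.lora_A.w": A1, "x.lora_B.w": B1},
#       {"x.lora_A.w": A2, "x.lora_B.w": B2},
#   ])
#   assert torch.allclose(
#       merged["x.lora_B.w"] @ merged["x.lora_A.w"], B1 @ A1 + B2 @ A2
#   )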
class Siglip2ImageEncoder(SiglipVisionTransformer):
    def __init__(self):
        config = SiglipVisionConfig(
            attention_dropout=0.0,
            dtype="float32",
            hidden_act="gelu_pytorch_tanh",
            hidden_size=1536,
            image_size=384,
            intermediate_size=6144,
            layer_norm_eps=1e-06,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_channels=3,
            num_hidden_layers=40,
            patch_size=16,
            transformers_version="4.56.1",
            _attn_implementation="sdpa",
        )
        # For compatibility with transformers
        import sys
        sys.modules["template_model"] = None
        super().__init__(config)
        self.processor = SiglipImageProcessor(
            do_convert_rgb=None,
            do_normalize=True,
            do_rescale=True,
            do_resize=True,
            image_mean=[0.5, 0.5, 0.5],
            image_processor_type="SiglipImageProcessor",
            image_std=[0.5, 0.5, 0.5],
            processor_class="SiglipProcessor",
            resample=2,
            rescale_factor=0.00392156862745098,
            size={"height": 384, "width": 384},
        )
    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda", query_embs=None):
        pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
        pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
        output_attentions = False
        output_hidden_states = False
        interpolate_pos_encoding = False
        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        last_hidden_state = encoder_outputs.last_hidden_state
        last_hidden_state = self.post_layernorm(last_hidden_state)
        if query_embs is None:
            # Default SigLIP attention pooling with the built-in probe.
            pooler_output = self.head(last_hidden_state)
        else:
            # Same computation as SiglipMultiheadAttentionPoolingHead.forward,
            # but with the provided learned queries in place of the single probe.
            hidden_state = self.head.attention(query_embs, last_hidden_state, last_hidden_state)[0]
            residual = hidden_state
            hidden_state = self.head.layernorm(hidden_state)
            pooler_output = residual + self.head.mlp(hidden_state)
        return pooler_output
class CompressedMLP(torch.nn.Module):
    def __init__(self, in_dim, mid_dim, out_dim, bias=False):
        super().__init__()
        self.proj_in = torch.nn.Linear(in_dim, mid_dim, bias=bias)
        self.proj_out = torch.nn.Linear(mid_dim, out_dim, bias=bias)

    def forward(self, x):
        x = self.proj_in(x)
        x = self.proj_out(x)
        return x
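# Note: with no activation between the two bias-free linears, CompressedMLP is a
# linear bottleneck rather than a true MLP; the composition is a single linear map
# of rank at most mid_dim, factorized (presumably) to save parameters. For the
# largest projector below (in_dim=1536, mid_dim=256, out_dim=27648 * 4 = 110592),
# a direct Linear would hold 1536 * 110592 ≈ 170M weights, while the factorized
# form holds 1536 * 256 + 256 * 110592 ≈ 28.7M.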
class ImageEmbeddingToLoraMatrix(torch.nn.Module):
    def __init__(self, in_dim, compress_dim, lora_a_dim, lora_b_dim, rank):
        super().__init__()
        self.proj_a = CompressedMLP(in_dim, compress_dim, lora_a_dim * rank)
        self.proj_b = CompressedMLP(in_dim, compress_dim, lora_b_dim * rank)
        self.lora_a_dim = lora_a_dim
        self.lora_b_dim = lora_b_dim
        self.rank = rank

    def forward(self, x):
        lora_a = self.proj_a(x).view(self.rank, self.lora_a_dim)
        lora_b = self.proj_b(x).view(self.lora_b_dim, self.rank)
        return lora_a, lora_b
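# Shape sketch (illustrative values taken from the to_qkv_mlp_proj pattern below):
# one pooled query embedding per target layer yields a full LoRA pair for that layer.
#
#   layer = ImageEmbeddingToLoraMatrix(1536, 256, 3072, 27648, rank=4)
#   emb = torch.randn(1, 1, 1536)
#   lora_a, lora_b = layer(emb)  # (4, 3072) down-projection, (27648, 4) up-projection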
class FLUX2Image2LoRAQuerys(torch.nn.Module):
    def __init__(self, length, dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.randn((1, length, dim)))

    def forward(self):
        return self.weights
class FLUX2Image2LoRAModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_patterns = [
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_qkv_mlp_proj",
                "num_blocks": 20,
                "dim_in": 3072,
                "dim_out": 27648,
            },
            {
                "name": "single_transformer_blocks.{block_id}.attn.to_out",
                "num_blocks": 20,
                "dim_in": 12288,
                "dim_out": 3072,
            },
        ]
        self.image_encoder = Siglip2ImageEncoder()
        self.parse_lora_layers(
            self.lora_patterns,
            dim_image=1536,
            compress_dim=256,
            rank=4,
        )
        # One learned pooling query per target LoRA layer.
        self.query_embs = FLUX2Image2LoRAQuerys(len(self.layers), 1536)
    def parse_lora_layers(self, lora_patterns, dim_image, compress_dim, rank):
        # Instantiate one projector per target layer (2 patterns x 20 blocks = 40).
        names = []
        layers = []
        for lora_pattern in lora_patterns:
            for block_id in range(lora_pattern["num_blocks"]):
                name = lora_pattern["name"].format(block_id=block_id)
                layer = ImageEmbeddingToLoraMatrix(dim_image, compress_dim, lora_pattern["dim_in"], lora_pattern["dim_out"], rank)
                names.append(name)
                layers.append(layer)
        self.names = names
        self.layers = torch.nn.ModuleList(layers)
    @torch.no_grad()
    def process_inputs(self, image, scale=1, **kwargs):
        return {"image": image, "scale": scale}
    def forward_single_image(self, image):
        embs = self.image_encoder(image, query_embs=self.query_embs.weights, device=self.query_embs.weights.device)
        # One pooled embedding per target layer: (1, num_layers, 1536) -> num_layers chunks.
        embs = embs.chunk(len(self.layers), dim=1)
        lora = {}
        for emb, name, layer in zip(embs, self.names, self.layers):
            lora_a, lora_b = layer(emb)
            # Key names follow the PEFT state-dict convention.
            lora[f"{name}.lora_A.default.weight"] = lora_a
            lora[f"{name}.lora_B.default.weight"] = lora_b
        return {"lora": lora}
    def forward(self, image, scale=1, **kwargs):
        if not isinstance(image, list):
            image = [image]
        loras = [self.forward_single_image(i)["lora"] for i in image]
        # alpha rescales lora_A so the merged update is scale times the mean of the
        # per-image updates; the ranks of the individual LoRAs are concatenated.
        lora = merge_lora(loras, alpha=1 / len(loras) * scale)
        return {"lora": lora}
class DataAnnotator:
    def __init__(self):
        from diffsynth.core import UnifiedDataset
        self.image_operator = UnifiedDataset.default_image_operator(
            base_path="",  # If your dataset contains relative paths, please specify the root path here.
            max_pixels=1024 * 1024,
            height_division_factor=16,
            width_division_factor=16,
        )

    def __call__(self, image, **kwargs):
        image = self.image_operator(image)
        return {"image": image}
def initialize_model_weights():
    from diffsynth import ModelConfig, load_state_dict
    from safetensors.torch import save_file
    import os
    config = ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
    config.download_if_necessary()
    state_dict = load_state_dict(config.path, torch_dtype=torch.bfloat16, device="cuda")
    model = FLUX2Image2LoRAModel().to(dtype=torch.bfloat16, device="cuda")
    model.image_encoder.load_state_dict(state_dict)
    # Initialize every query from the pretrained attention-pooling probe.
    query_embs = {"weights": torch.concat([state_dict["head.probe"]] * len(model.layers), dim=1)}
    model.query_embs.load_state_dict(query_embs, strict=False)
    lora_weights = {}
    for name, param in model.named_parameters():
        if ".proj_b.proj_out." in name:
            # Zero the final B-side projection so every generated lora_B starts at
            # zero and the initial LoRA update B @ A vanishes (standard LoRA init).
            lora_weights[name] = param * 0
        elif ".proj_b." in name or ".proj_a." in name:
            # Shrink the remaining projector weights for a gentler start.
            lora_weights[name] = param * 0.3
    model.load_state_dict(lora_weights, strict=False)
    print(sum(p.numel() for p in model.parameters()))
    save_file(model.state_dict(), os.path.join(os.path.dirname(__file__), "model.safetensors"))
TEMPLATE_MODEL = FLUX2Image2LoRAModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator