| | """ |
| | DeepGen Diffusers Pipeline - Standalone pipeline for DeepGen-1.0. |
| | |
| | This file is self-contained and does not require the DeepGen repository. |
| | It can be used with `trust_remote_code=True` when loading from HuggingFace Hub. |
| | |
| | Usage: |
| | import torch |
| | from diffusers import DiffusionPipeline |
| | pipe = DiffusionPipeline.from_pretrained( |
| | "deepgenteam/DeepGen-1.0-diffusers", |
| | torch_dtype=torch.bfloat16, |
| | trust_remote_code=True, |
| | ) |
| | pipe.to("cuda") |
| | |
| | # Text-to-Image |
| | image = pipe("a racoon holding a shiny red apple", height=512, width=512).images[0] |
| | |
| | # Image Edit |
| | from PIL import Image |
| | image = pipe("Place this guitar on a sandy beach.", |
| | image=Image.open("guitar.png"), height=512, width=512).images[0] |
| | """ |
| |
|
| | import inspect |
| | import math |
| | import os |
| | import json |
| | import warnings |
| | from functools import partial |
| | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.utils.checkpoint |
| | from torch.nn.init import _calculate_fan_in_and_fan_out |
| | from torch.nn.utils.rnn import pad_sequence |
| |
|
| | from einops import rearrange |
| | from PIL import Image |
| | from safetensors.torch import load_file |
| |
|
| | from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler |
| | from diffusers.configuration_utils import ConfigMixin, register_to_config |
| | from diffusers.image_processor import PipelineImageInput, VaeImageProcessor |
| | from diffusers.loaders import ( |
| | FromOriginalModelMixin, |
| | FromSingleFileMixin, |
| | PeftAdapterMixin, |
| | SD3IPAdapterMixin, |
| | SD3LoraLoaderMixin, |
| | SD3Transformer2DLoadersMixin, |
| | ) |
| | from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward |
| | from diffusers.models.attention_processor import ( |
| | Attention, |
| | AttentionProcessor, |
| | FusedJointAttnProcessor2_0, |
| | JointAttnProcessor2_0, |
| | ) |
| | from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed |
| | from diffusers.models.modeling_outputs import Transformer2DModelOutput |
| | from diffusers.models.modeling_utils import ModelMixin |
| | from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero |
| | from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
| | from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput |
| | from diffusers.utils import ( |
| | USE_PEFT_BACKEND, |
| | is_torch_xla_available, |
| | logging, |
| | scale_lora_layers, |
| | unscale_lora_layers, |
| | ) |
| | from diffusers.utils.torch_utils import maybe_allow_in_graph, randn_tensor |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | CLIPTextModelWithProjection, |
| | CLIPTokenizer, |
| | Qwen2_5_VLForConditionalGeneration, |
| | SiglipImageProcessor, |
| | SiglipVisionModel, |
| | T5EncoderModel, |
| | T5TokenizerFast, |
| | ) |
| | from transformers.activations import ACT2FN |
| | from transformers.configuration_utils import PretrainedConfig |
| | from transformers.utils import ( |
| | is_flash_attn_2_available, |
| | is_flash_attn_greater_or_equal_2_10, |
| | ) |
| |
|
| | if is_flash_attn_2_available(): |
| | from transformers.modeling_flash_attention_utils import _flash_attention_forward |
| |
|
| | if is_torch_xla_available(): |
| | import torch_xla.core.xla_model as xm |
| | XLA_AVAILABLE = True |
| | else: |
| | XLA_AVAILABLE = False |
| |
|
| |
|
logger = logging.get_logger(__name__)  # module-scoped logger from diffusers' logging utilities

# Three-element per-channel normalisation statistics (mean / std).
# NOTE(review): the values match the widely used OpenAI CLIP image statistics;
# their consumer is not visible in this part of the file — confirm where they are used.
IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ConnectorConfig(PretrainedConfig):
    """Hyper-parameters for the connector's transformer encoder.

    The connector attention / MLP / encoder classes read these attributes from the
    config they are given. Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Width and depth of the encoder stack.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Activation (an ACT2FN key), LayerNorm epsilon and attention dropout.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
| |
|
| |
|
def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std**2) samples truncated to ``[a, b]``.

    Inverse-CDF sampling: draw uniformly between the (erf-space) CDF values of the
    two bounds, map back through ``erfinv``, then scale, shift and clamp. Autograd
    is not disabled here; callers wrap the call in ``torch.no_grad()`` as needed.
    Warns when ``mean`` lies more than two ``std`` outside ``[a, b]``.
    """

    def _std_normal_cdf(value):
        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    lower_cdf = _std_normal_cdf((a - mean) / std)
    upper_cdf = _std_normal_cdf((b - mean) / std)
    # Every step is an in-place op returning ``tensor`` itself, so the chain mutates it directly.
    (
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
        .erfinv_()
        .mul_(std * math.sqrt(2.0))
        .add_(mean)
        .clamp_(min=a, max=b)
    )
| |
|
| |
|
def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """TensorFlow-style truncated-normal init, applied in place under ``torch.no_grad()``.

    The bounds ``a`` and ``b`` truncate a *standard* normal, which is then scaled by
    ``std`` and shifted by ``mean``. The effective value range is therefore
    ``[mean + a * std, mean + b * std]``.
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std)
        tensor.add_(mean)
| |
|
| |
|
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / fan`` (JAX-style variance scaling).

    Args:
        tensor: weight tensor; must be at least 2-D for ``_calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``, selecting the denominator.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values above.
            Previously an unknown distribution silently left ``tensor`` untouched.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fans = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}
    if mode not in fans:
        raise ValueError(f"invalid mode {mode!r}; expected one of {sorted(fans)}")
    variance = scale / fans[mode]

    if distribution == "truncated_normal":
        # 0.8796... is the std of a standard normal truncated to [-2, 2] (trunc_normal_tf_'s
        # default bounds); dividing by it restores the requested variance after truncation.
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-bound, bound) has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(
            f"invalid distribution {distribution!r}; expected 'truncated_normal', 'normal' or 'uniform'"
        )
| |
|
| |
|
def lecun_normal_(tensor):
    """LeCun normal init in place: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="truncated_normal")
| |
|
| |
|
def default_flax_embed_init(tensor):
    """Embedding init named after Flax's default: plain normal, variance ``1 / fan_in``, in place."""
    variance_scaling_(tensor, 1.0, "fan_in", "normal")
| |
|
| |
|
class ConnectorAttention(nn.Module):
    """Eager multi-head self-attention used by the connector encoder.

    Queries, keys and values come from separate ``nn.Linear`` projections of the
    same input. They are split into ``num_attention_heads`` heads, attended with
    softmax(QK^T * head_dim**-0.5) and merged back through ``out_proj``.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_dropout``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
                f"and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout
        # Registration order (k, v, q, out) is preserved so parameter order is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Self-attend over ``hidden_states``, unpacked as ``(batch, seq_len, _)``.

        Args:
            hidden_states: input sequence with last dim ``hidden_size``.
            attention_mask: optional *additive* mask (e.g. 0 / -inf) broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``; it is added to the logits.
            output_attentions: accepted for interface parity with the SDPA/flash
                variants; this eager path always returns the weights.

        Returns:
            ``(attn_output, attn_weights)`` shaped ``(batch, seq_len, hidden_size)``
            and ``(batch, num_heads, seq_len, seq_len)``.
        """
        batch_size, q_len, _ = hidden_states.size()
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in fp32 for numerical stability, then cast back to the activation dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
| |
|
| |
|
class ConnectorFlashAttention2(ConnectorAttention):
    """Connector self-attention computed through transformers' ``_flash_attention_forward``.

    Attention probabilities are never materialised, so the second return value is
    always ``None`` regardless of ``output_attentions``. The helper is only imported
    when ``is_flash_attn_2_available()`` is true.
    """

    is_causal = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tells the helper whether the installed flash-attn aligns causal masks top-left
        # (only relevant for causal attention; ``is_causal`` is False here).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Self-attend with FlashAttention-2 and return ``(attn_output, None)``.

        ``attention_mask`` is handed to ``_flash_attention_forward`` unchanged.
        NOTE(review): that helper presumably expects a 2-D padding mask
        ``(batch, seq_len)`` rather than the additive 4-D mask of the eager path — confirm.
        """
        batch_size, q_len, _ = hidden_states.size()
        # The flash helper consumes (batch, seq_len, num_heads, head_dim) directly, so the
        # head axis is not transposed in front of the sequence axis.
        head_shape = (batch_size, q_len, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).view(head_shape)
        key_states = self.k_proj(hidden_states).view(head_shape)
        value_states = self.v_proj(hidden_states).view(head_shape)

        dropout_rate = self.dropout if self.training else 0.0

        # Flash kernels need half precision; if activations arrive as fp32, cast them
        # to the autocast dtype, the pre-quantisation dtype, or the projection weight dtype.
        if query_states.dtype == torch.float32:
            if torch.is_autocast_enabled():
                # torch.get_autocast_gpu_dtype() is deprecated in newer PyTorch in favour
                # of torch.get_autocast_dtype("cuda"); fall back for older releases.
                if hasattr(torch, "get_autocast_dtype"):
                    target_dtype = torch.get_autocast_dtype("cuda")
                else:
                    target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len,
            dropout=dropout_rate, is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask)
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, None
| |
|
| |
|
class ConnectorSdpaAttention(ConnectorAttention):
    """Connector self-attention backed by ``torch.nn.functional.scaled_dot_product_attention``.

    When ``output_attentions`` is requested it defers to the eager parent
    implementation, because the fused call does not return attention weights.
    Otherwise it returns ``(attn_output, None)``.
    """

    is_causal = False

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        if output_attentions:
            # Only the eager path can hand back attention probabilities.
            return super().forward(hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions)

        bsz, seq_len, _ = hidden_states.size()
        split_heads = (bsz, seq_len, self.num_heads, self.head_dim)
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        q = self.q_proj(hidden_states).view(split_heads).transpose(1, 2)
        k = self.k_proj(hidden_states).view(split_heads).transpose(1, 2)
        v = self.v_proj(hidden_states).view(split_heads).transpose(1, 2)

        # On CUDA with an explicit mask, hand contiguous q/k/v to the fused kernel.
        if attention_mask is not None and q.device.type == "cuda":
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        fused = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=self.is_causal and seq_len > 1,
        )
        merged = fused.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.out_proj(merged), None
| |
|
| |
|
# Maps a config's ``_attn_implementation`` string to the connector attention class
# that ConnectorEncoderLayer instantiates.
CONNECTOR_ATTENTION_CLASSES = dict(
    eager=ConnectorAttention,
    flash_attention_2=ConnectorFlashAttention2,
    sdpa=ConnectorSdpaAttention,
)
| |
|
| |
|
class ConnectorMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(act(fc1(x)))``.

    ``fc1`` expands ``hidden_size -> intermediate_size`` and ``fc2`` projects back.
    The activation is looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
| |
|
| |
|
def _init_connector_weights(module):
    """Initialise one connector submodule in place; intended for ``nn.Module.apply``.

    - ``nn.Embedding``: ``default_flax_embed_init``.
    - ``ConnectorAttention``: Xavier-uniform q/k/v/out weights, zero biases.
    - ``ConnectorMLP``: Xavier-uniform fc1/fc2 weights, biases ~ N(0, 1e-6**2).
    - Any other ``nn.Linear`` / ``nn.Conv2d``: LeCun-normal weight, zero bias.
    - ``nn.LayerNorm``: weight 1, bias 0. Affine-free norms, whose ``weight``/``bias``
      are None, are skipped instead of raising AttributeError.
    """
    if isinstance(module, nn.Embedding):
        default_flax_embed_init(module.weight)
    elif isinstance(module, ConnectorAttention):
        projections = (module.q_proj, module.k_proj, module.v_proj, module.out_proj)
        # Weights first, then biases: same RNG consumption order as before.
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
        for proj in projections:
            nn.init.zeros_(proj.bias)
    elif isinstance(module, ConnectorMLP):
        nn.init.xavier_uniform_(module.fc1.weight)
        nn.init.xavier_uniform_(module.fc2.weight)
        nn.init.normal_(module.fc1.bias, std=1e-6)
        nn.init.normal_(module.fc2.bias, std=1e-6)
    elif isinstance(module, (nn.Linear, nn.Conv2d)):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)
| |
|
| |
|
class ConnectorEncoderLayer(nn.Module):
    """Pre-norm encoder layer: ``x + attn(LN1(x))`` followed by ``x + mlp(LN2(x))``.

    The attention class is chosen from ``CONNECTOR_ATTENTION_CLASSES`` by
    ``config._attn_implementation``. ``forward`` returns a 1-tuple ``(hidden_states,)``,
    or ``(hidden_states, attn_weights)`` when ``output_attentions`` is true.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        attention_cls = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = ConnectorMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask, output_attentions=False):
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
| |
|
| |
|
class ConnectorEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ConnectorEncoderLayer``s, run without a mask.

    Weights are initialised with ``_init_connector_weights`` at construction.
    Activation checkpointing is used per layer only while ``gradient_checkpointing``
    is set and the module is in training mode.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False
        self.apply(_init_connector_weights)

    def forward(self, inputs_embeds):
        hidden_states = inputs_embeds
        use_checkpointing = self.gradient_checkpointing and self.training
        for layer in self.layers:
            if use_checkpointing:
                # Positional args mirror layer(hidden_states, attention_mask=None, output_attentions=False).
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    layer.__call__, hidden_states, None, False, use_reentrant=False)
            else:
                layer_outputs = layer(hidden_states, None, output_attentions=False)
            hidden_states = layer_outputs[0]
        return hidden_states
| |
|
| |
|
class DeepGenConnector(nn.Module):
    """Connector module bridging VLM hidden states to DiT conditioning.

    ``llm2dit`` runs ``projector_1`` and then a ``ConnectorEncoder``. From the encoded
    sequence it produces a pooled vector (``projector_2`` applied to the mean over
    dim 1) and a per-token sequence (``projector_3``).

    ``meta_queries`` is a zero-initialised learnable ``(num_queries, llm_hidden_size)``
    parameter. It is not used by ``llm2dit``; presumably the caller consumes it — confirm.
    """

    def __init__(self, connector_config, num_queries, llm_hidden_size,
                 projector_1_in, projector_1_out,
                 projector_2_in, projector_2_out,
                 projector_3_in, projector_3_out):
        super().__init__()
        self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
        self.projector_1 = nn.Linear(projector_1_in, projector_1_out)
        self.projector_2 = nn.Linear(projector_2_in, projector_2_out)
        self.projector_3 = nn.Linear(projector_3_in, projector_3_out)
        self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))
        self.num_queries = num_queries

    def llm2dit(self, x):
        """Map VLM features ``x`` to ``(pooled_out, seq_out)`` DiT conditioning."""
        encoded = self.connector(self.projector_1(x))
        return self.projector_2(encoded.mean(1)), self.projector_3(encoded)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CustomJointAttnProcessor2_0:
    """Attention processor supporting attention masks for dynamic-resolution SD3.

    Joint attention: the image-stream projections and, when given, the text-stream
    projections are concatenated along the sequence axis and attended together in a
    single ``F.scaled_dot_product_attention`` call. ``attention_mask`` is honoured so
    padded image tokens can be excluded.
    """

    def __init__(self):
        # __call__ relies on F.scaled_dot_product_attention.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomJointAttnProcessor2_0 requires PyTorch 2.0+")

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, *args, **kwargs):
        """Run joint attention for one diffusers ``Attention`` module.

        Args:
            attn: the ``Attention`` module supplying projections (to_q/k/v, add_*_proj,
                to_out, to_add_out), optional q/k norms, ``heads`` and ``context_pre_only``.
            hidden_states: image-stream tokens; ``shape[0]`` is the batch size
                (assumed ``(batch, seq_len, dim)`` — TODO confirm).
            encoder_hidden_states: optional text-stream tokens, appended after the
                image tokens for attention.
            attention_mask: optional boolean token-validity mask for the image stream,
                ``True`` = real token (assumed ``(batch, image_seq_len)`` — confirm).
                Text tokens are always treated as valid.
            *args, **kwargs: accepted and ignored.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when ``encoder_hidden_states`` is
            given, otherwise only ``hidden_states``.
        """
        residual = hidden_states  # used only for its sequence length when splitting; not added back
        batch_size = hidden_states.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Split heads: -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            ctx_len = encoder_hidden_states.shape[1]
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states).view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Joint sequence = [image tokens, text tokens] along the sequence axis (dim 2).
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

            if attention_mask is not None:
                # Extend the image-token mask with all-True entries for the text tokens.
                encoder_attention_mask = torch.ones(
                    batch_size, ctx_len, dtype=torch.bool, device=hidden_states.device)
                attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=1)

        if attention_mask is not None:
            # Pairwise mask: entry [b, i, j] is True only when tokens i and j are both valid
            # (bool * bool acts as logical AND).
            attention_mask = attention_mask[:, None] * attention_mask[..., None]
            # Let every token attend to itself so padded query rows are never fully masked;
            # an all-False row can make SDPA's softmax produce NaN.
            indices = range(attention_mask.shape[1])
            attention_mask[:, indices, indices] = True
            # Add a head axis so the mask broadcasts over heads.
            attention_mask = attention_mask[:, None]

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)
        # Merge heads back: -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)  # defensive cast to the query dtype

        if encoder_hidden_states is not None:
            # Split the joint sequence back into image and text parts.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :residual.shape[1]],
                hidden_states[:, residual.shape[1]:])
            # The text output projection is skipped when attn.context_pre_only is set.
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        # Image output projection (to_out[0]) and presumably dropout (to_out[1]).
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
| |
|
| |
|
class CustomJointTransformerBlock(JointTransformerBlock):
    """``JointTransformerBlock`` whose attention layers honour an ``attention_mask``.

    Installs ``CustomJointAttnProcessor2_0`` on ``attn`` (and on ``attn2`` when the
    block uses dual attention). ``forward`` threads ``attention_mask`` into both.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn.set_processor(CustomJointAttnProcessor2_0())
        if self.attn2 is not None:
            self.attn2.set_processor(CustomJointAttnProcessor2_0())

    def forward(self, hidden_states, encoder_hidden_states, temb,
                attention_mask=None, joint_attention_kwargs=None):
        """Apply one joint image/text block.

        Args:
            hidden_states: image-stream tokens.
            encoder_hidden_states: text-stream tokens.
            temb: conditioning embedding fed to the adaptive norms.
            attention_mask: optional image-token validity mask forwarded to the processors.
            joint_attention_kwargs: extra kwargs forwarded to the attention calls.

        Returns:
            ``(encoder_hidden_states, hidden_states)``, text stream first.
            ``encoder_hidden_states`` is ``None`` when ``context_pre_only`` is set.
        """
        joint_attention_kwargs = joint_attention_kwargs or {}
        # norm1 returns the normalised image tokens plus shift/scale/gate modulation terms
        # (and a second normalised input + gate for the dual-attention branch).
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(hidden_states, emb=temb)
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)

        # Joint attention over image + text tokens.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, attention_mask=attention_mask,
            encoder_hidden_states=norm_encoder_hidden_states, **joint_attention_kwargs)

        # Gated residual update of the image stream.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            # Second, image-only self-attention branch.
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, attention_mask=attention_mask, **joint_attention_kwargs)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        # Modulated feed-forward with gated residual (chunked when a chunk size is set).
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = hidden_states + ff_output

        if self.context_pre_only:
            # This block does not update the text stream.
            encoder_hidden_states = None
        else:
            # Same gated attention + modulated FF update for the text stream.
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output
            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                context_ff_output = _chunked_feed_forward(self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size)
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states
| |
|
| |
|
class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """SD3-style MMDiT that accepts per-sample, variable-resolution latents.

    Processing in ``forward``:
      - Each sample's latent is patch-embedded separately.
      - Optional reference latents (``cond_hidden_states``) are patch-embedded and
        appended to that sample's token sequence.
      - All sequences are right-padded to a common length.
      - A boolean mask marking the real tokens is passed to every
        ``CustomJointTransformerBlock``.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["JointTransformerBlock", "CustomJointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[int, ...] = (),
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

        # The last block is context_pre_only: it reads the text stream but does not update it.
        self.transformer_blocks = nn.ModuleList(
            [
                CustomJointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=i in dual_attention_layers,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        """Return every attention processor in the model, keyed by ``"<module path>.processor"``."""
        processors = {}

        def fn_recursive_add_processors(name, module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        """Set the attention processor(s) for every attention layer.

        Args:
            processor: one processor applied to every layer, or a dict mapping the keys
                of ``attn_processors`` to processors. A dict must contain exactly one
                entry per attention layer; the caller's dict is not modified.

        Raises:
            ValueError: if a dict with the wrong number of entries is passed.

        NOTE: the blocks install ``CustomJointAttnProcessor2_0`` to honour the padding
        mask built in ``forward``. A processor that ignores ``attention_mask`` would let
        padded tokens take part in attention.
        """
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}.")
        if isinstance(processor, dict):
            # Pop from a private copy so the caller's dict is left intact.
            processor = dict(processor)

        def fn_recursive_attn_processor(name, module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        cond_hidden_states=None,
        pooled_projections=None,
        timestep=None,
        block_controlnet_hidden_states=None,
        joint_attention_kwargs=None,
        return_dict=True,
        skip_layers=None,
    ):
        """Predict the denoising target for a batch of possibly differently sized latents.

        Args:
            hidden_states: batch tensor or sequence of per-sample latents. Each item is
                fed to ``pos_embed`` as ``item[None]`` and its last two dims are read as
                the latent height/width (presumably ``(C, H, W)`` — confirm).
            encoder_hidden_states: text-stream features, projected by ``context_embedder``.
            cond_hidden_states: optional per-sample iterables of reference latents. Their
                tokens are appended after the sample's own tokens and dropped from the output.
            pooled_projections: pooled conditioning for ``time_text_embed``.
            timestep: diffusion timestep(s) for ``time_text_embed``.
            block_controlnet_hidden_states: optional ControlNet residuals, added after every
                block that is not ``context_pre_only``.
            joint_attention_kwargs: extra kwargs for the attention processors. A ``"scale"``
                entry sets the LoRA scale (PEFT backend only).
            return_dict: return a ``Transformer2DModelOutput`` instead of a 1-tuple.
            skip_layers: optional collection of block indices to skip.

        Returns:
            ``Transformer2DModelOutput(sample=...)`` or ``(sample,)``. ``sample`` is a
            stacked tensor when every per-sample output ``(C, H, W)`` has the same shape,
            otherwise a list of them.
        """
        lora_scale = 1.0
        scale_requested = False
        if joint_attention_kwargs is not None:
            # Copy so popping entries below never mutates the caller's dict.
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Record this before popping: once popped, the key can no longer be detected.
            scale_requested = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning("Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.")

        latent_sizes = [hs.shape[-2:] for hs in hidden_states]
        bsz = len(hidden_states)

        # Per-sample token sequences: own latent tokens first, then any reference tokens.
        hidden_states_list = []
        for idx in range(bsz):
            pieces = [self.pos_embed(hidden_states[idx][None])[0]]
            if cond_hidden_states is not None:
                pieces.extend(self.pos_embed(ref[None])[0] for ref in cond_hidden_states[idx])
            hidden_states_list.append(torch.cat(pieces))

        # True marks real tokens; padding added by pad_sequence below stays False.
        seq_lens = [len(hs) for hs in hidden_states_list]
        attention_mask = torch.zeros(bsz, max(seq_lens), dtype=torch.bool, device=self.device)
        for i, seq_len in enumerate(seq_lens):
            attention_mask[i, :seq_len] = True

        # pad_sequence pads on the right by default, matching the mask above.
        hidden_states = pad_sequence(hidden_states_list, batch_first=True, padding_value=0.0)

        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            # NOTE(review): ``self.image_proj`` is not created in this class (presumably added
            # when IP-Adapter weights are loaded). CustomJointAttnProcessor2_0 accepts but does
            # not use ``ip_hidden_states`` / ``temb`` — confirm IP-Adapter conditioning works.
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        for index_block, block in enumerate(self.transformer_blocks):
            is_skip = skip_layers is not None and index_block in skip_layers
            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, joint_attention_kwargs
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # ControlNet residuals are spread evenly over the non-final blocks.
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        patch_size = self.config.patch_size
        latent_sizes = [(ls[0] // patch_size, ls[1] // patch_size) for ls in latent_sizes]

        # Keep only each sample's own latent tokens (drop reference tokens and padding)
        # and unpatchify them back to (C, H, W).
        output = [
            rearrange(
                hs[: math.prod(latent_size)],
                "(h w) (p q c) -> c (h p) (w q)",
                h=latent_size[0],
                w=latent_size[1],
                p=patch_size,
                q=patch_size,
            )
            for hs, latent_size in zip(hidden_states, latent_sizes)
        ]

        # Stack only when all samples share a shape; mixed resolutions stay a list. The check
        # is explicit so unrelated errors are not swallowed by a blanket exception handler.
        if output and all(o.shape == output[0].shape for o in output):
            output = torch.stack(output)

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    """Map ``image_seq_len`` linearly to a timestep-shift value ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that range are extrapolated,
    not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
| |
|
| |
|
def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, sigmas=None, **kwargs):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    Exactly one schedule source is used: custom ``timesteps``, custom ``sigmas``,
    or ``num_inference_steps`` when neither is given. Extra ``kwargs`` are forwarded
    to ``scheduler.set_timesteps``. For custom schedules the returned step count is
    the length of the scheduler's resulting timesteps.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are passed, or if the
            scheduler's ``set_timesteps`` does not accept the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    resolved = scheduler.timesteps
    return resolved, len(resolved)
| |
|
| |
|
class _SD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
    """Internal SD3 sampling pipeline with ``cond_latents`` support.

    This is a trimmed-down ``StableDiffusion3Pipeline``: it has no prompt-encoding
    path, so callers must supply pre-computed ``prompt_embeds`` and
    ``pooled_prompt_embeds`` (plus the negative pair when classifier-free guidance
    is active). Optional ``cond_latents`` are forwarded to the transformer as
    ``cond_hidden_states`` on every denoising step.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(self, transformer, scheduler, vae, text_encoder, tokenizer,
                 text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3,
                 image_encoder=None, feature_extractor=None):
        super().__init__()
        self.register_modules(
            vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler,
            image_encoder=image_encoder, feature_extractor=feature_extractor)
        # Pixel-to-latent downsampling factor derived from the VAE block count
        # (diffusers convention); 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
        self.default_sample_size = self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128
        self.patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2

    def check_inputs(self, prompt, prompt_2, prompt_3, height, width, negative_prompt=None,
                     negative_prompt_2=None, negative_prompt_3=None, prompt_embeds=None,
                     negative_prompt_embeds=None, pooled_prompt_embeds=None,
                     negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None,
                     max_sequence_length=None):
        """Validate call arguments before any computation starts.

        Text prompt arguments are accepted for signature compatibility only.

        Raises:
            ValueError: if ``height``/``width`` are not multiples of
                ``vae_scale_factor * patch_size``; if an unknown callback tensor
                name is requested; if ``prompt_embeds`` is missing (this pipeline
                cannot encode text); or if an embedding is passed without its
                pooled counterpart.
        """
        multiple = self.vae_scale_factor * self.patch_size
        if height % multiple != 0 or width % multiple != 0:
            raise ValueError(f"`height` and `width` have to be divisible by {multiple}.")
        if callback_on_step_end_tensor_inputs is not None:
            unknown = [k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]
            if unknown:
                raise ValueError(
                    f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, "
                    f"but found {unknown}.")
        if prompt_embeds is None:
            raise ValueError(
                "This pipeline has no text encoders; `prompt_embeds` and `pooled_prompt_embeds` must be provided.")
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError("If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed.")
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError("If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed.")

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        """Return initial latents: the given ``latents`` moved to ``device``/``dtype``,
        or fresh Gaussian noise of shape ``(batch, channels, H / vae_scale_factor, W / vae_scale_factor)``."""
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        shape = (batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}.")
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for scales strictly above 1.
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @torch.no_grad()
    def __call__(
        self,
        prompt=None, prompt_2=None, prompt_3=None,
        height=None, width=None, num_inference_steps=28, sigmas=None,
        guidance_scale=7.0,
        negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None,
        num_images_per_prompt=1, generator=None, latents=None,
        cond_latents=None,
        prompt_embeds=None, negative_prompt_embeds=None,
        pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
        output_type="pil", return_dict=True,
        joint_attention_kwargs=None, callback_on_step_end=None,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=256, mu=None, **kwargs,
    ):
        """Run the flow-matching denoising loop on pre-computed prompt embeddings.

        Args:
            prompt_embeds / pooled_prompt_embeds: Required conditioning; their first
                dimension defines the prompt batch size.
            negative_prompt_embeds / negative_pooled_prompt_embeds: Required when
                ``guidance_scale > 1``; concatenated in front of the positive pair.
            cond_latents: Optional conditioning forwarded as ``cond_hidden_states``.
                It is expected to be a Python list (``* 2`` repeats it); when its
                length equals the latent batch under CFG it is duplicated to
                cover both CFG halves.
            num_images_per_prompt: Each embedding row is repeated this many times
                (copies of one prompt are contiguous).
            callback_on_step_end_tensor_inputs: Names of step locals handed to the
                callback; defaults to ``["latents"]``.
            output_type: ``"latent"`` returns raw latents; otherwise they are VAE
                decoded and post-processed.
            mu: Explicit dynamic-shift value; computed from the latent size when
                the scheduler uses dynamic shifting and ``mu`` is ``None``.
            The ``prompt*``/``negative_prompt*`` text arguments are not encoded.

        Returns:
            ``StableDiffusion3PipelineOutput`` or ``(image,)`` when ``return_dict=False``.
        """
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        self.check_inputs(
            prompt, prompt_2, prompt_3, height, width,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            negative_prompt_3=negative_prompt_3,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
            max_sequence_length=max_sequence_length,
        )

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        if self.do_classifier_free_guidance and (
                negative_prompt_embeds is None or negative_pooled_prompt_embeds is None):
            raise ValueError(
                "Classifier-free guidance (`guidance_scale > 1`) requires `negative_prompt_embeds` "
                "and `negative_pooled_prompt_embeds`.")

        # The embeddings are the only conditioning actually consumed, so they define the batch.
        batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        if num_images_per_prompt > 1:
            # Same layout as diffusers' encode_prompt: copies of one prompt are contiguous.
            prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            pooled_prompt_embeds = pooled_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            if self.do_classifier_free_guidance:
                negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat_interleave(
                    num_images_per_prompt, dim=0)

        if self.do_classifier_free_guidance:
            # Unconditional half first; matches the chunk(2) order used below.
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt, num_channels_latents, height, width,
            prompt_embeds.dtype, device, generator, latents)

        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, h, w = latents.shape
            image_seq_len = (h // self.transformer.config.patch_size) * (w // self.transformer.config.patch_size)
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16))
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu

        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        if cond_latents is not None and self.do_classifier_free_guidance:
            # One conditioning entry per latent row: duplicate for both CFG halves
            # unless the caller already did so.
            if len(cond_latents) == latents.shape[0]:
                cond_latents = cond_latents * 2

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self._interrupt:
                    continue
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                timestep = t.expand(latent_model_input.shape[0])
                noise_pred = self.transformer(
                    hidden_states=latent_model_input, cond_hidden_states=cond_latents,
                    timestep=timestep, encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False)[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    # Some MPS builds return a different dtype from the scheduler step.
                    if torch.backends.mps.is_available():
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    # Plain loop on purpose: locals() inside a comprehension would see
                    # the comprehension's own scope on Python < 3.12.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return StableDiffusion3PipelineOutput(images=image)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DeepGenPipeline(DiffusionPipeline):
    """
    DeepGen 1.0 Pipeline for text-to-image generation and image editing.

    This pipeline integrates Qwen2.5-VL (VLM) + SCB Connector + SD3 DiT into a
    single interface. Standard diffusers components (transformer, vae, scheduler)
    are loaded by DiffusionPipeline; non-standard components (VLM, connector,
    tokenizer, prompt_template) are loaded automatically on first use.

    Usage:
        pipe = DiffusionPipeline.from_pretrained(
            "deepgenteam/DeepGen-1.0-diffusers",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        pipe.to("cuda")
        result = pipe("a raccoon holding an apple", height=512, width=512)
        result.images[0].save("output.png")
    """

    _optional_components = []

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        # Re-class the stock transformer in place so it accepts cond_latents.
        self._upgrade_transformer()

        # State for the lazily loaded, non-diffusers components (see `_load_extras`).
        self._extras_loaded = False
        self._cpu_offload = False
        self._gpu_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.lmm = None
        self.tokenizer = None
        self.connector_module = None
        self.prompt_template = None
        self.max_length = 1024
        self.image_token_id = None
        # Normalisation statistics applied to VLM inputs in `get_semantic_features`.
        self.vit_mean = torch.tensor(IMAGE_MEAN)
        self.vit_std = torch.tensor(IMAGE_STD)

    def _upgrade_transformer(self):
        """Convert standard diffusers SD3Transformer2DModel to custom version
        with cond_latents support for image editing. No weight copying needed."""
        from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
        if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
            self.transformer.__class__ = SD3Transformer2DModel
            for block in self.transformer.transformer_blocks:
                block.__class__ = CustomJointTransformerBlock
                block.attn.set_processor(CustomJointAttnProcessor2_0())
                # `attn2` only exists (non-None) on dual-attention blocks.
                attn2 = getattr(block, "attn2", None)
                if attn2 is not None:
                    attn2.set_processor(CustomJointAttnProcessor2_0())

    def _resolve_pretrained_path(self):
        """Return a local directory for the pipeline checkpoint, downloading it
        from the Hub when ``_name_or_path`` is not an existing directory."""
        path = self.config._name_or_path
        if os.path.isdir(path):
            return path
        from huggingface_hub import snapshot_download
        return snapshot_download(repo_id=path)

    def _load_extras(self, vlm_model_path=None, attn_implementation="flash_attention_2"):
        """Load non-standard components (VLM, tokenizer, connector, prompt_template).

        Does nothing once the components are loaded.

        Args:
            vlm_model_path: Explicit VLM checkpoint (local dir or Hub id). When
                ``None``, a ``vlm/`` sub-folder of the pipeline checkpoint is
                preferred, then the ``"vlm"`` entry of ``model_index.json``, then
                ``"Qwen/Qwen2.5-VL-3B-Instruct"``.
            attn_implementation: Attention backend requested for the VLM. If
                loading with it fails, loading is retried once with ``"sdpa"``
                (a warning is emitted); a failure with ``"sdpa"`` itself is raised.
        """
        if self._extras_loaded:
            return
        path = self._resolve_pretrained_path()
        dtype = next(self.transformer.parameters()).dtype

        model_index_path = os.path.join(path, "model_index.json")
        extra_cfg = {}
        if os.path.isfile(model_index_path):
            with open(model_index_path, "r") as f:
                extra_cfg = json.load(f)

        # --- VLM -----------------------------------------------------------
        vlm_path = vlm_model_path
        if vlm_path is None:
            local_merged = os.path.join(path, "vlm")
            if os.path.isdir(local_merged):
                vlm_path = local_merged
            else:
                vlm_path = extra_cfg.get("vlm", "Qwen/Qwen2.5-VL-3B-Instruct")
                if not os.path.isdir(vlm_path):
                    # NOTE(review): developer-local mirror lookup; a no-op when absent.
                    local_candidate = os.path.join("/data/huggingface", vlm_path.split("/")[-1])
                    if os.path.isdir(local_candidate):
                        vlm_path = local_candidate
        print(f"Loading VLM from {vlm_path}...")
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation=attn_implementation)
        except Exception as err:
            if attn_implementation == "sdpa":
                raise
            warnings.warn(
                f"Loading the VLM with attn_implementation={attn_implementation!r} failed ({err}); "
                "falling back to 'sdpa'.")
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation="sdpa")
        self.lmm.requires_grad_(False)

        # --- Tokenizer -----------------------------------------------------
        print("Loading tokenizer...")
        tokenizer_path = os.path.join(path, "tokenizer")
        if os.path.isdir(tokenizer_path):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, trust_remote_code=True, padding_side='right')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                vlm_path, trust_remote_code=True, padding_side='right')

        # --- Connector -----------------------------------------------------
        print("Loading connector...")
        connector_dir = os.path.join(path, "connector")
        with open(os.path.join(connector_dir, "config.json"), "r") as f:
            connector_cfg = json.load(f)

        conn_cfg = connector_cfg["connector"].copy()
        conn_cfg["_attn_implementation"] = "sdpa"

        self.connector_module = DeepGenConnector(
            connector_config=conn_cfg,
            num_queries=connector_cfg["num_queries"],
            llm_hidden_size=connector_cfg["llm_hidden_size"],
            projector_1_in=connector_cfg["projector_1_in"],
            projector_1_out=connector_cfg["projector_1_out"],
            projector_2_in=connector_cfg["projector_2_in"],
            projector_2_out=connector_cfg["projector_2_out"],
            projector_3_in=connector_cfg["projector_3_in"],
            projector_3_out=connector_cfg["projector_3_out"],
        )
        connector_state = load_file(os.path.join(connector_dir, "model.safetensors"))
        self.connector_module.load_state_dict(connector_state, strict=True)
        self.connector_module = self.connector_module.to(dtype=dtype)

        # --- Prompt template -----------------------------------------------
        prompt_template_path = os.path.join(path, "prompt_template.json")
        with open(prompt_template_path, "r") as f:
            self.prompt_template = json.load(f)

        self.max_length = connector_cfg.get("max_length", 1024)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(
            self.prompt_template['IMG_CONTEXT_TOKEN'])

        # Under CPU offload the VLM/connector stay on CPU until `__call__` needs them.
        if not self._cpu_offload:
            device = self._gpu_device
            self.lmm = self.lmm.to(device=device)
            self.connector_module = self.connector_module.to(device=device, dtype=dtype)

        self.vit_mean = self.vit_mean.to(device=self._gpu_device)
        self.vit_std = self.vit_std.to(device=self._gpu_device)

        self._extras_loaded = True
        print("All components loaded.")

    @property
    def llm(self):
        """The language-model part of the loaded VLM."""
        return self.lmm.language_model

    @property
    def num_queries(self):
        """Number of learnable meta-query tokens appended to the LLM input."""
        return self.connector_module.num_queries

    def to(self, *args, **kwargs):
        """Move the pipeline, including already-loaded extras, to a device and/or dtype.

        Accepts the same arguments as ``DiffusionPipeline.to``. A device given
        positionally (``torch.device`` or ``str``) or as ``device=`` also becomes
        the compute device used for VLM inputs, generators and offloading.
        """
        result = super().to(*args, **kwargs)
        device = None
        dtype = None
        for arg in args:
            if isinstance(arg, torch.device):
                device = arg
            elif isinstance(arg, str):
                device = torch.device(arg)
            elif isinstance(arg, torch.dtype):
                dtype = arg
        device = device or kwargs.get("device")
        dtype = dtype or kwargs.get("dtype")
        if isinstance(device, str):
            # Normalise so generator creation and offload comparisons see a torch.device.
            device = torch.device(device)

        if device is not None:
            self._gpu_device = device
        if self._extras_loaded:
            if device is not None:
                self.lmm = self.lmm.to(device=device)
                self.connector_module = self.connector_module.to(device=device)
                self.vit_mean = self.vit_mean.to(device=device)
                self.vit_std = self.vit_std.to(device=device)
            if dtype is not None:
                self.lmm = self.lmm.to(dtype=dtype)
                self.connector_module = self.connector_module.to(dtype=dtype)
        return result

    def enable_model_cpu_offload(self, gpu_id=None, device=None):
        """Enable model-level CPU offload to reduce peak GPU memory usage.

        All models are parked on CPU; ``__call__`` moves each stage (VLM +
        connector, VAE, transformer) to the compute device only while it runs.

        Args:
            gpu_id: CUDA device index to compute on (used when ``device`` is None).
            device: Explicit compute device (``str`` or ``torch.device``).
        """
        self._cpu_offload = True
        if device is not None:
            self._gpu_device = torch.device(device) if isinstance(device, str) else device
        elif gpu_id is not None:
            self._gpu_device = torch.device(f"cuda:{gpu_id}")
        self.transformer = self.transformer.to("cpu")
        self.vae = self.vae.to("cpu")
        if self._extras_loaded:
            self.lmm = self.lmm.to("cpu")
            self.connector_module = self.connector_module.to("cpu")
            # The small normalisation tensors stay on the compute device.
            self.vit_mean = self.vit_mean.to(self._gpu_device)
            self.vit_std = self.vit_std.to(self._gpu_device)
        torch.cuda.empty_cache()

    def _offload_to(self, module, device):
        """Move ``module`` to ``device``; release cached CUDA memory when moving to CPU."""
        module.to(device)
        if device == torch.device("cpu") or device == "cpu":
            torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the full pipeline. When called directly (not via DiffusionPipeline),
        loads all components immediately including VLM and connector.

        Extra kwargs ``vlm_model_path`` and ``attn_implementation`` are consumed
        here and forwarded to ``_load_extras``; the rest go to
        ``DiffusionPipeline.from_pretrained``.
        """
        vlm_model_path = kwargs.pop("vlm_model_path", None)
        attn_implementation = kwargs.pop("attn_implementation", "flash_attention_2")

        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)

        pipe._load_extras(vlm_model_path=vlm_model_path,
                          attn_implementation=attn_implementation)
        return pipe

    @torch.no_grad()
    def pixels_to_latents(self, x, generator=None):
        """Encode pixels with the VAE and apply the SD3 shift/scale normalisation.

        Args:
            x: Image batch passed straight to ``vae.encode``.
            generator: Optional ``torch.Generator`` used to sample the latent
                posterior; ``None`` keeps using the global RNG.
        """
        z = self.vae.encode(x).latent_dist.sample(generator=generator)
        z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return z

    @torch.no_grad()
    def latents_to_pixels(self, z):
        """Undo the SD3 latent normalisation and decode with the VAE."""
        z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        x_rec = self.vae.decode(z).sample
        return x_rec

    def prepare_text2image_prompts(self, texts):
        """Wrap each text in the GENERATION then INSTRUCTION templates and
        tokenize with left padding on the compute device."""
        texts = [self.prompt_template['GENERATION'].format(input=text) for text in texts]
        texts = [self.prompt_template['INSTRUCTION'].format(input=text) for text in texts]
        return self.tokenizer(
            texts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_image2image_prompts(self, texts, num_refs, ref_lens):
        """Build INSTRUCTION prompts with image placeholder tokens and tokenize them.

        Args:
            texts: One text per sample.
            num_refs: Number of reference images per sample.
            ref_lens: Placeholder-token count for every reference image, flattened
                across samples in order (consumed sequentially).
        """
        prompts = []
        cnt = 0
        for text, num_ref in zip(texts, num_refs):
            image_tokens = ''
            for _ in range(num_ref):
                image_tokens += (self.prompt_template['IMG_START_TOKEN'] +
                                 self.prompt_template['IMG_CONTEXT_TOKEN'] * ref_lens[cnt] +
                                 self.prompt_template['IMG_END_TOKEN'])
                cnt += 1
            prompts.append(self.prompt_template['INSTRUCTION'].format(
                input=f'{image_tokens}\n{text}'))
        return self.tokenizer(
            prompts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_forward_input(self, query_embeds, input_ids=None,
                              image_embeds=None, image_grid_thw=None,
                              attention_mask=None, past_key_values=None):
        """Assemble LLM inputs: token embeddings (with image features scattered into
        the image-placeholder positions) followed by the meta-query embeddings.

        Position ids come from the VLM's multimodal RoPE index computed over the
        tokens plus ``l`` zero placeholders for the queries.
        """
        b, l, _ = query_embeds.shape
        attention_mask = attention_mask.to(device=self._gpu_device, dtype=torch.bool)
        # Reserve positions for the queries so RoPE indices cover them.
        input_ids = torch.cat([input_ids, input_ids.new_zeros(b, l)], dim=1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)

        position_ids, _ = self.lmm.model.get_rope_index(
            input_ids=input_ids, image_grid_thw=image_grid_thw,
            video_grid_thw=None, second_per_grid_ts=None,
            attention_mask=attention_mask)

        if past_key_values is not None:
            inputs_embeds = query_embeds
            position_ids = position_ids[..., -l:]
        else:
            input_ids = input_ids[:, :-l]
            if image_embeds is None:
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            else:
                inputs_embeds = torch.zeros(
                    *input_ids.shape, self.llm.config.hidden_size,
                    device=self._gpu_device, dtype=self.transformer.dtype)
                inputs_embeds[input_ids == self.image_token_id] = \
                    image_embeds.contiguous().view(-1, self.llm.config.hidden_size)
                inputs_embeds[input_ids != self.image_token_id] = \
                    self.llm.get_input_embeddings()(input_ids[input_ids != self.image_token_id])
            inputs_embeds = torch.cat([inputs_embeds, query_embeds], dim=1)

        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                    position_ids=position_ids, past_key_values=past_key_values)

    @torch.no_grad()
    def get_semantic_features(self, pixel_values, resize=True):
        """Run the VLM vision tower on ``pixel_values`` (values assumed in [-1, 1]).

        Returns:
            ``(image_embeds, image_grid_thw)`` with embeddings reshaped to
            ``(b, tokens, dim)`` and one ``(t, h, w)`` grid row per image.
        """
        pixel_values = (pixel_values + 1.0) / 2
        pixel_values = pixel_values - self.vit_mean.view(1, 3, 1, 1)
        pixel_values = pixel_values / self.vit_std.view(1, 3, 1, 1)

        if resize:
            pixel_values = F.interpolate(pixel_values, size=(448, 448), mode='bilinear')
        b, c, h, w = pixel_values.shape

        patch_size = self.lmm.config.vision_config.patch_size
        spatial_merge_size = self.lmm.config.vision_config.spatial_merge_size
        temporal_patch_size = self.lmm.config.vision_config.temporal_patch_size

        # Treat each still image as a clip of `temporal_patch_size` identical frames.
        pixel_values = pixel_values[:, None].expand(b, temporal_patch_size, c, h, w)
        grid_t = 1
        grid_h, grid_w = h // patch_size, w // patch_size

        pixel_values = pixel_values.view(
            b, grid_t, temporal_patch_size, c,
            grid_h // spatial_merge_size, spatial_merge_size, patch_size,
            grid_w // spatial_merge_size, spatial_merge_size, patch_size)
        pixel_values = rearrange(
            pixel_values, 'b t tp c h m p w n q -> (b t h w m n) (c tp p q)')

        image_grid_thw = torch.tensor(
            [(grid_t, grid_h, grid_w)] * b).to(self._gpu_device).long()
        image_embeds = self.lmm.visual(pixel_values, grid_thw=image_grid_thw)
        image_embeds = rearrange(image_embeds, '(b l) d -> b l d', b=b)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def get_semantic_features_dynamic(self, pixel_values):
        """Per-image variant of `get_semantic_features` that keeps each image's
        aspect ratio, rescaling it by 28/32 instead of resizing to 448x448.

        Returns:
            A list of per-image embeddings and the concatenated grid tensor.
        """
        def multi_apply(func, *args, **kwargs):
            pfunc = partial(func, **kwargs) if kwargs else func
            map_results = map(pfunc, *args)
            return tuple(map(list, zip(*map_results)))

        pixel_values = [F.interpolate(p[None], scale_factor=28/32, mode='bilinear')
                        for p in pixel_values]
        image_embeds, image_grid_thw = multi_apply(
            self.get_semantic_features, pixel_values, resize=False)
        image_embeds = [x[0] for x in image_embeds]
        image_grid_thw = torch.cat(image_grid_thw, dim=0)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Optional[Union[Image.Image, List[Image.Image]]] = None,
        negative_prompt: str = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 4.0,
        seed: Optional[int] = None,
        num_images_per_prompt: int = 1,
    ):
        """
        Generate or edit images.

        Args:
            prompt: Text prompt for generation/editing.
            image: Optional input image(s) for editing. If None, does text-to-image.
                Every reference image is resized to ``(width, height)``.
            negative_prompt: Negative prompt for CFG.
            height: Output image height.
            width: Output image width.
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG guidance scale; values <= 1 disable CFG.
            seed: Random seed; seeds both the initial noise and, for editing,
                the VAE posterior sampling of the reference latents.
            num_images_per_prompt: Number of images to generate per prompt.

        Returns:
            SimpleNamespace with .images attribute (list of PIL Images).
        """
        from types import SimpleNamespace
        self._load_extras()

        offload = self._cpu_offload
        gpu = self._gpu_device

        if isinstance(prompt, str):
            prompt = [prompt]
        b = len(prompt) * num_images_per_prompt
        prompt = prompt * num_images_per_prompt
        cfg_prompt = [negative_prompt] * b
        do_cfg = guidance_scale > 1  # same rule as _SD3Pipeline.do_classifier_free_guidance

        generator = None
        if seed is not None:
            generator = torch.Generator(device=gpu).manual_seed(seed)

        # ---- Stage 1: VLM + connector -> DiT conditioning -------------------
        if offload:
            self._offload_to(self.lmm, gpu)
            self._offload_to(self.connector_module, gpu)

        cond_latents = None
        if image is not None:
            if isinstance(image, Image.Image):
                image = [image]
            ref_images = []
            for ref in image:
                ref = ref.convert('RGB').resize((width, height))
                pv = torch.from_numpy(np.array(ref)).float() / 255.0
                pv = 2 * pv - 1  # [0, 1] -> [-1, 1]
                pv = rearrange(pv, 'h w c -> c h w')
                ref_images.append(pv.to(dtype=self.transformer.dtype, device=gpu))

            pixel_values_src = [list(ref_images)] * b
            num_refs = [len(ref_images)] * b
            image_embeds, image_grid_thw = self.get_semantic_features_dynamic(
                [ref for ref_imgs in pixel_values_src for ref in ref_imgs])
            ref_lens = [len(x) for x in image_embeds]

            # Positive prompts followed by negative prompts, each with the same references.
            text_inputs = self.prepare_image2image_prompts(
                prompt + cfg_prompt, num_refs=num_refs * 2, ref_lens=ref_lens * 2)
            text_inputs.update(
                image_embeds=torch.cat(image_embeds * 2),
                image_grid_thw=torch.cat([image_grid_thw] * 2))

            if offload:
                self._offload_to(self.vae, gpu)
            cond_latents = [
                [self.pixels_to_latents(ref[None], generator=generator)[0] for ref in ref_imgs]
                for ref_imgs in pixel_values_src
            ]
            if do_cfg:
                # One entry per row of the CFG batch; without CFG the transformer
                # only sees the `b` conditional rows.
                cond_latents = cond_latents * 2
            if offload:
                self._offload_to(self.vae, "cpu")
        else:
            text_inputs = self.prepare_text2image_prompts(prompt + cfg_prompt)

        hidden_states = self.connector_module.meta_queries[None].expand(
            2 * b, self.num_queries, -1)
        inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
        output = self.llm(**inputs, return_dict=True, output_hidden_states=True)

        # Fuse every 6th hidden layer, counting down from the second-to-last one.
        hidden_states = output.hidden_states
        num_layers = len(hidden_states) - 1
        selected_layers = list(range(num_layers - 1, 0, -6))
        selected_hiddens = [hidden_states[i] for i in selected_layers]
        merged_hidden = torch.cat(selected_hiddens, dim=-1)
        pooled_out, seq_out = self.connector_module.llm2dit(merged_hidden)

        if offload:
            del output, hidden_states, selected_hiddens, merged_hidden
            self._offload_to(self.lmm, "cpu")
            self._offload_to(self.connector_module, "cpu")

        # ---- Stage 2: DiT denoising -----------------------------------------
        if offload:
            self._offload_to(self.transformer, gpu)

        pipeline = _SD3Pipeline(
            transformer=self.transformer, scheduler=self.scheduler,
            vae=self.vae, text_encoder=None, tokenizer=None,
            text_encoder_2=None, tokenizer_2=None,
            text_encoder_3=None, tokenizer_3=None)
        # Honour any `set_progress_bar_config` the user applied to this pipeline.
        pipeline.set_progress_bar_config(**getattr(self, "_progress_bar_config", {}))

        # First `b` rows are the positive prompts, the last `b` the negatives.
        samples = pipeline(
            height=height, width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            prompt_embeds=seq_out[:b],
            pooled_prompt_embeds=pooled_out[:b],
            negative_prompt_embeds=seq_out[b:],
            negative_pooled_prompt_embeds=pooled_out[b:],
            generator=generator,
            output_type='latent',
            cond_latents=cond_latents,
        ).images.to(self.transformer.dtype)

        if offload:
            self._offload_to(self.transformer, "cpu")

        # ---- Stage 3: VAE decode --------------------------------------------
        if offload:
            self._offload_to(self.vae, gpu)

        pixels = self.latents_to_pixels(samples)

        if offload:
            self._offload_to(self.vae, "cpu")

        # Quantise in float32: bfloat16 only has integer resolution in [128, 256),
        # which would defeat the +0.5 rounding offset folded into `+ 128.0`
        # (127.5 * x + 127.5 maps [-1, 1] onto [0, 255]).
        pixels = pixels.float()
        images = []
        for frame in pixels:
            hwc = rearrange(frame, 'c h w -> h w c')
            arr = torch.clamp(127.5 * hwc + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
            images.append(Image.fromarray(arr))

        return SimpleNamespace(images=images)
| |
|