"""DeepGen Diffusers Pipeline - Standalone pipeline for DeepGen-1.0.

This file is self-contained and does not require the DeepGen repository.
It can be used with `trust_remote_code=True` when loading from HuggingFace Hub.

Usage:
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "deepgenteam/DeepGen-1.0-diffusers",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    pipe.to("cuda")

    # Text-to-Image
    image = pipe("a racoon holding a shiny red apple", height=512, width=512).images[0]

    # Image Edit
    from PIL import Image
    image = pipe("Place this guitar on a sandy beach.",
                 image=Image.open("guitar.png"), height=512, width=512).images[0]
"""
import inspect
import math
import os
import json
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn.init import _calculate_fan_in_and_fan_out
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
    FromOriginalModelMixin,
    FromSingleFileMixin,
    PeftAdapterMixin,
    SD3IPAdapterMixin,
    SD3LoraLoaderMixin,
    SD3Transformer2DLoadersMixin,
)
from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward
from diffusers.models.attention_processor import (
    Attention,
    AttentionProcessor,
    FusedJointAttnProcessor2_0,
    JointAttnProcessor2_0,
)
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph, randn_tensor
from transformers import (
    AutoTokenizer,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    SiglipImageProcessor,
    SiglipVisionModel,
    T5EncoderModel,
    T5TokenizerFast,
)
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import (
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
)

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)

# SigLIP-style normalization constants applied to conditioning images.
IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)


# =============================================================================
# Connector: Config + Attention + MLP + Encoder
# =============================================================================
class ConnectorConfig(PretrainedConfig):
    """Configuration for the connector transformer encoder (SigLIP-like layout)."""

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act


def _trunc_normal_(tensor, mean, std, a, b):
    """In-place truncated normal init via inverse-CDF sampling (timm-style)."""

    def norm_cdf(x):
        # Standard normal CDF.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    # Sample uniformly in CDF space, then map back through the inverse CDF.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)
    tensor.uniform_(2 * l - 1, 2 * u - 1)
    tensor.erfinv_()
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """TF/JAX-style truncated normal: truncate a unit normal, then scale/shift."""
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Flax-style variance-scaling initializer (truncated_normal / normal / uniform)."""
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    denom = {"fan_in": fan_in, "fan_out": fan_out, "fan_avg": (fan_in + fan_out) / 2}[mode]
    variance = scale / denom
    if distribution == "truncated_normal":
        # Constant corrects the std of a [-2, 2]-truncated standard normal.
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


class ConnectorAttention(nn.Module):
    """Eager multi-head self-attention for the connector encoder."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
                f"and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Returns (attn_output, attn_weights)."""
        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in fp32 for numerical stability, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights


class ConnectorFlashAttention2(ConnectorAttention):
    """Flash-attention-2 variant; never returns attention weights."""

    is_causal = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        output_attentions = False  # flash attention cannot expose weights
        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Flash attention expects (batch, seq, heads, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.dropout if self.training else 0.0

        # fp32 inputs (e.g. from layer norms upcast) must be cast down for flash attention.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len,
            dropout=dropout_rate, is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask)

        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, None


class ConnectorSdpaAttention(ConnectorAttention):
    """PyTorch SDPA variant; falls back to eager when weights are requested."""

    is_causal = False

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        if output_attentions:
            # SDPA cannot return attention weights; use the eager path instead.
            return super().forward(hidden_states=hidden_states, attention_mask=attention_mask,
                                   output_attentions=output_attentions)
        batch_size, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Workaround for a CUDA SDPA bug with non-contiguous inputs + custom mask.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        is_causal = True if self.is_causal and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, None


CONNECTOR_ATTENTION_CLASSES = {
    "eager": ConnectorAttention,
    "flash_attention_2": ConnectorFlashAttention2,
    "sdpa": ConnectorSdpaAttention,
}


class ConnectorMLP(nn.Module):
    """Standard two-layer transformer MLP with configurable activation."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


def _init_connector_weights(module):
    """SigLIP-style weight init applied recursively over connector modules."""
    if isinstance(module, nn.Embedding):
        default_flax_embed_init(module.weight)
    elif isinstance(module, ConnectorAttention):
        nn.init.xavier_uniform_(module.q_proj.weight)
        nn.init.xavier_uniform_(module.k_proj.weight)
        nn.init.xavier_uniform_(module.v_proj.weight)
        nn.init.xavier_uniform_(module.out_proj.weight)
        nn.init.zeros_(module.q_proj.bias)
        nn.init.zeros_(module.k_proj.bias)
        nn.init.zeros_(module.v_proj.bias)
        nn.init.zeros_(module.out_proj.bias)
    elif isinstance(module, ConnectorMLP):
        nn.init.xavier_uniform_(module.fc1.weight)
        nn.init.xavier_uniform_(module.fc2.weight)
        nn.init.normal_(module.fc1.bias, std=1e-6)
        nn.init.normal_(module.fc2.bias, std=1e-6)
    elif isinstance(module, (nn.Linear, nn.Conv2d)):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


class ConnectorEncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer: attention + MLP with residuals."""

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation](config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = ConnectorMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask, output_attentions=False):
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states, attention_mask=attention_mask,
            output_attentions=output_attentions)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class ConnectorEncoder(nn.Module):
    """Stack of connector encoder layers with optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.apply(_init_connector_weights)

    def forward(self, inputs_embeds):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if self.gradient_checkpointing and self.training:
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    encoder_layer.__call__, hidden_states, None, False, use_reentrant=False)
            else:
                layer_outputs = encoder_layer(hidden_states, None, output_attentions=False)
            hidden_states = layer_outputs[0]
        return hidden_states


class DeepGenConnector(nn.Module):
    """Connector module bridging VLM hidden states to DiT conditioning."""

    def __init__(self, connector_config, num_queries, llm_hidden_size,
                 projector_1_in, projector_1_out, projector_2_in, projector_2_out,
                 projector_3_in, projector_3_out):
        super().__init__()
        self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
        # projector_1: LLM hidden -> connector; projector_2/3: connector -> pooled/sequence cond.
        self.projector_1 = nn.Linear(projector_1_in, projector_1_out)
        self.projector_2 = nn.Linear(projector_2_in, projector_2_out)
        self.projector_3 = nn.Linear(projector_3_in, projector_3_out)
        # Learnable query tokens appended to the VLM input to collect generation features.
        self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))
        self.num_queries = num_queries

    def llm2dit(self, x):
        """Map VLM hidden states to (pooled_out, seq_out) DiT conditioning tensors."""
        x = self.connector(self.projector_1(x))
        pooled_out = self.projector_2(x.mean(1))
        seq_out = self.projector_3(x)
        return pooled_out, seq_out


# =============================================================================
# Custom SD3 Transformer (dynamic resolution + attention mask)
# =============================================================================
class CustomJointAttnProcessor2_0:
    """Attention processor supporting attention masks for dynamic-resolution SD3."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomJointAttnProcessor2_0 requires PyTorch 2.0+")

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, *args, **kwargs):
        residual = hidden_states
        batch_size = hidden_states.shape[0]

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            ctx_len = encoder_hidden_states.shape[1]
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states).view(
                batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states).view(
                batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states).view(
                batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # Joint attention: concatenate image and text tokens along the sequence.
            query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
            key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
            value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

            if attention_mask is not None:
                # Text tokens are always attendable.
                encoder_attention_mask = torch.ones(
                    batch_size, ctx_len, dtype=torch.bool, device=hidden_states.device)
                attention_mask = torch.cat([attention_mask, encoder_attention_mask], dim=1)

        if attention_mask is not None:
            # (b, s) key-padding mask -> (b, 1, s, s) pairwise mask; force the diagonal
            # True so fully-padded rows still attend to themselves (avoids NaNs).
            attention_mask = attention_mask[:, None] * attention_mask[..., None]
            indices = range(attention_mask.shape[1])
            attention_mask[:, indices, indices] = True
            attention_mask = attention_mask[:, None]

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # Split back into image and text streams.
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :residual.shape[1]],
                hidden_states[:, residual.shape[1]:])
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class CustomJointTransformerBlock(JointTransformerBlock):
    """JointTransformerBlock whose forward accepts an `attention_mask`."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn.set_processor(CustomJointAttnProcessor2_0())
        if self.attn2 is not None:
            self.attn2.set_processor(CustomJointAttnProcessor2_0())

    def forward(self, hidden_states, encoder_hidden_states, temb,
                attention_mask=None, joint_attention_kwargs=None):
        joint_attention_kwargs = joint_attention_kwargs or {}

        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = \
                self.norm1(hidden_states, emb=temb)
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = \
                self.norm1_context(encoder_hidden_states, emb=temb)

        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, attention_mask=attention_mask,
            encoder_hidden_states=norm_encoder_hidden_states, **joint_attention_kwargs)

        # Image stream: gated attention residual.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2,
                                      attention_mask=attention_mask, **joint_attention_kwargs)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = hidden_states + ff_output

        # Text stream (dropped entirely in the final, context_pre_only block).
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size)
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                            FromOriginalModelMixin, SD3Transformer2DLoadersMixin):
    """SD3 MM-DiT variant supporting per-sample latent sizes and reference latents."""

    _supports_gradient_checkpointing = True
    _no_split_modules = ["JointTransformerBlock", "CustomJointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[int, ...] = (),
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = PatchEmbed(
            height=sample_size, width=sample_size, patch_size=patch_size,
            in_channels=in_channels, embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size)
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim)
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)

        self.transformer_blocks = nn.ModuleList([
            CustomJointTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                context_pre_only=i == num_layers - 1,
                qk_norm=qk_norm,
                use_dual_attention=True if i in dual_attention_layers else False,
            )
            for i in range(num_layers)
        ])

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim,
                                               elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = False

    @property
    def attn_processors(self):
        """Map of fully-qualified module names to their attention processors."""
        processors = {}

        def fn_recursive_add_processors(name, module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        """Set a single processor (or a per-layer dict) on every attention module."""
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match "
                f"the number of attention layers: {count}.")

        def fn_recursive_attn_processor(name, module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        cond_hidden_states=None,
        pooled_projections=None,
        timestep=None,
        block_controlnet_hidden_states=None,
        joint_attention_kwargs=None,
        return_dict=True,
        skip_layers=None,
    ):
        """Denoise per-sample latents.

        `hidden_states` is a sequence of (C, H, W) latents (sizes may differ per
        sample); `cond_hidden_states` optionally holds reference latents per
        sample that are appended to the token sequence for image editing.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.")

        latent_sizes = [hs.shape[-2:] for hs in hidden_states]
        bsz = len(hidden_states)

        # Patchify each sample individually, appending reference tokens if present.
        hidden_states_list = []
        for idx in range(bsz):
            hidden_states_per_sample = self.pos_embed(hidden_states[idx][None])[0]
            if cond_hidden_states is not None:
                for ref in cond_hidden_states[idx]:
                    hidden_states_per_sample = torch.cat(
                        [hidden_states_per_sample, self.pos_embed(ref[None])[0]])
            hidden_states_list.append(hidden_states_per_sample)

        # Right-pad to a rectangular batch and mask out the padding tokens.
        max_len = max([len(hs) for hs in hidden_states_list])
        attention_mask = torch.zeros(bsz, max_len, dtype=torch.bool, device=self.device)
        for i, hs in enumerate(hidden_states_list):
            attention_mask[i, :len(hs)] = True
        hidden_states = pad_sequence(hidden_states_list, batch_first=True,
                                     padding_value=0.0, padding_side='right')

        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
            ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
            ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
            joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

        for index_block, block in enumerate(self.transformer_blocks):
            is_skip = True if skip_layers is not None and index_block in skip_layers else False
            if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, temb,
                    attention_mask, joint_attention_kwargs)
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states,
                    temb=temb, attention_mask=attention_mask,
                    joint_attention_kwargs=joint_attention_kwargs)
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify per-sample: drop padding/reference tokens, restore (C, H, W).
        patch_size = self.config.patch_size
        latent_sizes = [(ls[0] // patch_size, ls[1] // patch_size) for ls in latent_sizes]
        output = [
            rearrange(hs[:math.prod(latent_size)], '(h w) (p q c) -> c (h p) (w q)',
                      h=latent_size[0], w=latent_size[1], p=patch_size, q=patch_size)
            for hs, latent_size in zip(hidden_states, latent_sizes)
        ]
        try:
            output = torch.stack(output)
        except RuntimeError:
            # Mixed latent sizes in the batch cannot be stacked; keep a list of tensors.
            pass

        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)


# =============================================================================
# Custom StableDiffusion3Pipeline (with cond_latents + dynamic shift)
# =============================================================================
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    """Linearly interpolate the flow-matching shift `mu` from the token count."""
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


def retrieve_timesteps(scheduler, num_inference_steps=None, device=None,
                       timesteps=None, sigmas=None, **kwargs):
    """Call `scheduler.set_timesteps` with custom timesteps/sigmas when supported.

    Returns `(timesteps, num_inference_steps)`. At most one of `timesteps` and
    `sigmas` may be given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom "
                f"timestep schedules.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom "
                f"sigmas schedules.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class _SD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
    """Internal SD3 pipeline with cond_latents support."""

    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds",
                               "negative_pooled_prompt_embeds"]

    def __init__(self, transformer, scheduler, vae, text_encoder, tokenizer,
                 text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3,
                 image_encoder=None, feature_extractor=None):
        super().__init__()
        self.register_modules(
            vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2,
            text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2,
            tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler,
            image_encoder=image_encoder, feature_extractor=feature_extractor)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.tokenizer_max_length = (
            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77)
        self.default_sample_size = (
            self.transformer.config.sample_size
            if hasattr(self, "transformer") and self.transformer is not None else 128)
        self.patch_size = (
            self.transformer.config.patch_size
            if hasattr(self, "transformer") and self.transformer is not None else 2)

    def check_inputs(self, prompt, prompt_2, prompt_3, height, width,
                     negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None,
                     prompt_embeds=None, negative_prompt_embeds=None,
                     pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
                     callback_on_step_end_tensor_inputs=None, max_sequence_length=None):
        """Validate resolution and embedding argument combinations."""
        if height % (self.vae_scale_factor * self.patch_size) != 0 or \
                width % (self.vae_scale_factor * self.patch_size) != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size}.")
        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError("If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed.")
        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed.")

    def prepare_latents(self, batch_size, num_channels_latents, height, width,
                        dtype, device, generator, latents=None):
        """Return caller-provided latents (moved to device) or fresh Gaussian noise."""
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        shape = (batch_size, num_channels_latents,
                 int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective "
                f"batch size of {batch_size}.")
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def joint_attention_kwargs(self):
        return self._joint_attention_kwargs

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        prompt_2=None,
        prompt_3=None,
        height=None,
        width=None,
        num_inference_steps=28,
        sigmas=None,
        guidance_scale=7.0,
        negative_prompt=None,
        negative_prompt_2=None,
        negative_prompt_3=None,
        num_images_per_prompt=1,
        generator=None,
        latents=None,
        cond_latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        output_type="pil",
        return_dict=True,
        joint_attention_kwargs=None,
        callback_on_step_end=None,
        callback_on_step_end_tensor_inputs=["latents"],
        max_sequence_length=256,
        mu=None,
        **kwargs,
    ):
        """Run the denoising loop.

        NOTE(review): this internal pipeline expects `prompt_embeds` /
        `pooled_prompt_embeds` (and their negatives, under CFG) to be
        precomputed by the caller — no text encoding happens here.
        """
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        self.check_inputs(
            prompt, prompt_2, prompt_3, height, width,
            negative_prompt=negative_prompt, prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds)

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._interrupt = False

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        device = self._execution_device

        # CFG: stack negative (unconditional) and positive embeddings.
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt, num_channels_latents, height, width,
            prompt_embeds.dtype, device, generator, latents)

        # Resolution-dependent timestep shift (flow matching).
        scheduler_kwargs = {}
        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
            _, _, h, w = latents.shape
            image_seq_len = (h // self.transformer.config.patch_size) * (w // self.transformer.config.patch_size)
            mu = calculate_shift(
                image_seq_len,
                self.scheduler.config.get("base_image_seq_len", 256),
                self.scheduler.config.get("max_image_seq_len", 4096),
                self.scheduler.config.get("base_shift", 0.5),
                self.scheduler.config.get("max_shift", 1.16))
            scheduler_kwargs["mu"] = mu
        elif mu is not None:
            scheduler_kwargs["mu"] = mu
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs)
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # Duplicate reference latents so both CFG branches see them.
        if cond_latents is not None and self.do_classifier_free_guidance:
            if len(cond_latents) == latents.shape[0]:
                cond_latents = cond_latents * 2

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self._interrupt:
                    continue

                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    cond_hidden_states=cond_latents,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False)[0]

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
                if latents.dtype != latents_dtype:
                    if torch.backends.mps.is_available():
                        # MPS can silently upcast; restore the working dtype.
                        latents = latents.to(latents_dtype)

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if output_type == "latent":
            image = latents
        else:
            # Undo the VAE latent normalization before decoding.
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents, return_dict=False)[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return StableDiffusion3PipelineOutput(images=image)


#
# =============================================================================
# DeepGen Pipeline (main entry point)
# =============================================================================


class DeepGenPipeline(DiffusionPipeline):
    """
    DeepGen 1.0 Pipeline for text-to-image generation and image editing.

    This pipeline integrates Qwen2.5-VL (VLM) + SCB Connector + SD3 DiT into a
    single interface. Standard diffusers components (transformer, vae,
    scheduler) are loaded by DiffusionPipeline; non-standard components
    (VLM, connector, tokenizer, prompt_template) are loaded automatically on
    first use.

    Usage:
        pipe = DiffusionPipeline.from_pretrained(
            "deepgenteam/DeepGen-1.0-diffusers",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        pipe.to("cuda")
        result = pipe("a raccoon holding an apple", height=512, width=512)
        result.images[0].save("output.png")
    """

    _optional_components = []

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        """Register diffusers components and initialize lazy-loaded slots.

        The VLM, tokenizer, connector, and prompt template stay `None` until
        `_load_extras` runs (first call or explicit `from_pretrained`).
        """
        super().__init__()
        self.register_modules(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        self._upgrade_transformer()
        self._extras_loaded = False
        self._cpu_offload = False
        self._gpu_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.lmm = None
        self.tokenizer = None
        self.connector_module = None
        self.prompt_template = None
        self.max_length = 1024
        self.image_token_id = None
        # SigLIP/CLIP-style normalization constants for the vision tower.
        self.vit_mean = torch.tensor(IMAGE_MEAN)
        self.vit_std = torch.tensor(IMAGE_STD)

    def _upgrade_transformer(self):
        """Convert standard diffusers SD3Transformer2DModel to custom version
        with cond_latents support for image editing. No weight copying needed.

        The upgrade is done in place by swapping `__class__` on the model and
        its blocks and installing the custom attention processors, so the
        already-loaded weights are reused as-is.
        """
        from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
        if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
            self.transformer.__class__ = SD3Transformer2DModel
            for block in self.transformer.transformer_blocks:
                block.__class__ = CustomJointTransformerBlock
                block.attn.set_processor(CustomJointAttnProcessor2_0())
                if block.attn2 is not None:
                    block.attn2.set_processor(CustomJointAttnProcessor2_0())

    def _resolve_pretrained_path(self):
        """Return a local directory for the pipeline repo, downloading it
        from the Hub if `_name_or_path` is not already a directory."""
        path = self.config._name_or_path
        if os.path.isdir(path):
            return path
        from huggingface_hub import snapshot_download
        return snapshot_download(repo_id=path)

    def _load_extras(self, vlm_model_path=None, attn_implementation="flash_attention_2"):
        """Load non-standard components (VLM, connector, tokenizer, prompt_template).

        Idempotent: returns immediately once `_extras_loaded` is set. All
        extra modules are cast to the transformer's dtype; they are moved to
        `_gpu_device` unless CPU offload is enabled.
        """
        if self._extras_loaded:
            return
        path = self._resolve_pretrained_path()
        dtype = next(self.transformer.parameters()).dtype
        model_index_path = os.path.join(path, "model_index.json")
        extra_cfg = {}
        if os.path.isfile(model_index_path):
            with open(model_index_path, "r") as f:
                extra_cfg = json.load(f)
        # Resolve VLM path: prefer local merged VLM (with LoRA baked in)
        vlm_path = vlm_model_path
        if vlm_path is None:
            local_merged = os.path.join(path, "vlm")
            if os.path.isdir(local_merged):
                vlm_path = local_merged
            else:
                vlm_path = extra_cfg.get("vlm", "Qwen/Qwen2.5-VL-3B-Instruct")
                if not os.path.isdir(vlm_path):
                    # NOTE(review): machine-specific cache fallback; harmless
                    # elsewhere (only used if the directory exists) but
                    # consider making this configurable.
                    local_candidate = os.path.join("/data/huggingface", vlm_path.split("/")[-1])
                    if os.path.isdir(local_candidate):
                        vlm_path = local_candidate
        logger.info(f"Loading VLM from {vlm_path}...")
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation=attn_implementation)
        except Exception:
            # Fall back to SDPA when flash-attention is unavailable/broken.
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation="sdpa")
        self.lmm.requires_grad_(False)
        logger.info("Loading tokenizer...")
        tokenizer_path = os.path.join(path, "tokenizer")
        if os.path.isdir(tokenizer_path):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, trust_remote_code=True, padding_side='right')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                vlm_path, trust_remote_code=True, padding_side='right')
        logger.info("Loading connector...")
        connector_dir = os.path.join(path, "connector")
        with open(os.path.join(connector_dir, "config.json"), "r") as f:
            connector_cfg = json.load(f)
        conn_cfg = connector_cfg["connector"].copy()
        conn_cfg["_attn_implementation"] = "sdpa"
        self.connector_module = DeepGenConnector(
            connector_config=conn_cfg,
            num_queries=connector_cfg["num_queries"],
            llm_hidden_size=connector_cfg["llm_hidden_size"],
            projector_1_in=connector_cfg["projector_1_in"],
            projector_1_out=connector_cfg["projector_1_out"],
            projector_2_in=connector_cfg["projector_2_in"],
            projector_2_out=connector_cfg["projector_2_out"],
            projector_3_in=connector_cfg["projector_3_in"],
            projector_3_out=connector_cfg["projector_3_out"],
        )
        connector_state = load_file(os.path.join(connector_dir, "model.safetensors"))
        self.connector_module.load_state_dict(connector_state, strict=True)
        self.connector_module = self.connector_module.to(dtype=dtype)
        prompt_template_path = os.path.join(path, "prompt_template.json")
        with open(prompt_template_path, "r") as f:
            self.prompt_template = json.load(f)
        self.max_length = connector_cfg.get("max_length", 1024)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(
            self.prompt_template['IMG_CONTEXT_TOKEN'])
        if not self._cpu_offload:
            device = self._gpu_device
            self.lmm = self.lmm.to(device=device)
            self.connector_module = self.connector_module.to(device=device, dtype=dtype)
        # Normalization buffers are tiny; keep them on the compute device
        # even in offload mode so `get_semantic_features` can subtract them.
        self.vit_mean = self.vit_mean.to(device=self._gpu_device)
        self.vit_std = self.vit_std.to(device=self._gpu_device)
        self._extras_loaded = True
        logger.info("All components loaded.")

    @property
    def llm(self):
        # Language-model half of the Qwen2.5-VL checkpoint.
        return self.lmm.language_model

    @property
    def num_queries(self):
        return self.connector_module.num_queries

    def to(self, *args, **kwargs):
        """Move the pipeline; also moves the lazily-loaded VLM/connector and
        keeps `_gpu_device` in sync with the requested device."""
        result = super().to(*args, **kwargs)
        device = None
        dtype = None
        for a in args:
            if isinstance(a, torch.device):
                device = a
            elif isinstance(a, str):
                device = torch.device(a)
            elif isinstance(a, torch.dtype):
                dtype = a
        device = device or kwargs.get("device")
        dtype = dtype or kwargs.get("dtype")
        # Normalize a string passed via kwargs (positional strings are
        # already converted above) so `_gpu_device` is always torch.device.
        if isinstance(device, str):
            device = torch.device(device)
        if device is not None:
            self._gpu_device = device
        if self._extras_loaded:
            if device is not None:
                self.lmm = self.lmm.to(device=device)
                self.connector_module = self.connector_module.to(device=device)
                self.vit_mean = self.vit_mean.to(device=device)
                self.vit_std = self.vit_std.to(device=device)
            if dtype is not None:
                self.lmm = self.lmm.to(dtype=dtype)
                self.connector_module = self.connector_module.to(dtype=dtype)
        return result

    def enable_model_cpu_offload(self, gpu_id=None, device=None):
        """Enable sequential CPU offload to reduce GPU memory usage (~14GB)."""
        self._cpu_offload = True
        if device is not None:
            self._gpu_device = torch.device(device) if isinstance(device, str) else device
        elif gpu_id is not None:
            self._gpu_device = torch.device(f"cuda:{gpu_id}")
        # Park the heavy modules on CPU; `__call__` shuttles each stage's
        # module to `_gpu_device` just-in-time.
        self.transformer = self.transformer.to("cpu")
        self.vae = self.vae.to("cpu")
        if self._extras_loaded:
            self.lmm = self.lmm.to("cpu")
            self.connector_module = self.connector_module.to("cpu")
        self.vit_mean = self.vit_mean.to(self._gpu_device)
        self.vit_std = self.vit_std.to(self._gpu_device)
        torch.cuda.empty_cache()

    def _offload_to(self, module, device):
        """Move `module` to `device`, freeing CUDA cache on moves to CPU."""
        module.to(device)
        if device == torch.device("cpu") or device == "cpu":
            torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the full pipeline. When called directly (not via
        DiffusionPipeline), loads all components immediately including VLM
        and connector.

        Extra kwargs:
            vlm_model_path: override for the VLM checkpoint location.
            attn_implementation: attention backend for the VLM
                (default "flash_attention_2", falls back to "sdpa").
        """
        vlm_model_path = kwargs.pop("vlm_model_path", None)
        attn_implementation = kwargs.pop("attn_implementation", "flash_attention_2")
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        pipe._load_extras(vlm_model_path=vlm_model_path, attn_implementation=attn_implementation)
        return pipe

    @torch.no_grad()
    def pixels_to_latents(self, x):
        """Encode pixels in [-1, 1] into normalized VAE latents."""
        z = self.vae.encode(x).latent_dist.sample()
        z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return z

    @torch.no_grad()
    def latents_to_pixels(self, z):
        """Decode normalized VAE latents back to pixel space."""
        z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        x_rec = self.vae.decode(z).sample
        return x_rec

    def prepare_text2image_prompts(self, texts):
        """Wrap plain prompts in the GENERATION + INSTRUCTION templates and
        tokenize with left padding (batch-decoding friendly)."""
        texts = [self.prompt_template['GENERATION'].format(input=text) for text in texts]
        texts = [self.prompt_template['INSTRUCTION'].format(input=text) for text in texts]
        return self.tokenizer(
            texts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_image2image_prompts(self, texts, num_refs, ref_lens):
        """Build editing prompts with per-reference image placeholder spans.

        Args:
            texts: one instruction per batch element.
            num_refs: number of reference images per batch element.
            ref_lens: flat list of IMG_CONTEXT token counts, one entry per
                reference image across the whole batch (consumed in order).
        """
        prompts = []
        cnt = 0
        for text, num_ref in zip(texts, num_refs):
            image_tokens = ''
            for _ in range(num_ref):
                image_tokens += (self.prompt_template['IMG_START_TOKEN']
                                 + self.prompt_template['IMG_CONTEXT_TOKEN'] * ref_lens[cnt]
                                 + self.prompt_template['IMG_END_TOKEN'])
                cnt += 1
            prompts.append(self.prompt_template['INSTRUCTION'].format(
                input=f'{image_tokens}\n{text}'))
        return self.tokenizer(
            prompts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_forward_input(self, query_embeds, input_ids=None, image_embeds=None, image_grid_thw=None, attention_mask=None, past_key_values=None):
        """Assemble LLM forward inputs with meta-queries appended to the text.

        Appends `l` query slots after the prompt (zero token ids, mask=1),
        computes MRoPE position ids over the extended sequence, and splices
        vision-tower embeddings into the IMG_CONTEXT positions.

        Returns:
            dict of kwargs for `self.llm(...)`.
        """
        b, l, _ = query_embeds.shape
        attention_mask = attention_mask.to(device=self._gpu_device, dtype=torch.bool)
        input_ids = torch.cat([input_ids, input_ids.new_zeros(b, l)], dim=1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)
        position_ids, _ = self.lmm.model.get_rope_index(
            input_ids=input_ids, image_grid_thw=image_grid_thw,
            video_grid_thw=None, second_per_grid_ts=None,
            attention_mask=attention_mask)
        if past_key_values is not None:
            # Incremental step: only the query slots are fed; keep only
            # their position ids.
            inputs_embeds = query_embeds
            position_ids = position_ids[..., -l:]
        else:
            input_ids = input_ids[:, :-l]
            if image_embeds is None:
                inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            else:
                # Scatter vision embeddings into IMG_CONTEXT slots and text
                # embeddings everywhere else.
                inputs_embeds = torch.zeros(
                    *input_ids.shape, self.llm.config.hidden_size,
                    device=self._gpu_device, dtype=self.transformer.dtype)
                inputs_embeds[input_ids == self.image_token_id] = \
                    image_embeds.contiguous().view(-1, self.llm.config.hidden_size)
                inputs_embeds[input_ids != self.image_token_id] = \
                    self.llm.get_input_embeddings()(input_ids[input_ids != self.image_token_id])
            inputs_embeds = torch.cat([inputs_embeds, query_embeds], dim=1)
        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                    position_ids=position_ids, past_key_values=past_key_values)

    @torch.no_grad()
    def get_semantic_features(self, pixel_values, resize=True):
        """Run the VLM vision tower on images in [-1, 1].

        Re-normalizes to the vision tower's mean/std, optionally resizes to
        448x448, and re-packs pixels into Qwen2.5-VL's flattened patch layout
        before calling `self.lmm.visual`.

        Returns:
            (image_embeds of shape (b, l, d), image_grid_thw of shape (b, 3)).
        """
        # [-1, 1] -> [0, 1] -> vision-tower normalization.
        pixel_values = (pixel_values + 1.0) / 2
        pixel_values = pixel_values - self.vit_mean.view(1, 3, 1, 1)
        pixel_values = pixel_values / self.vit_std.view(1, 3, 1, 1)
        if resize:
            pixel_values = F.interpolate(pixel_values, size=(448, 448), mode='bilinear')
        b, c, h, w = pixel_values.shape
        patch_size = self.lmm.config.vision_config.patch_size
        spatial_merge_size = self.lmm.config.vision_config.spatial_merge_size
        temporal_patch_size = self.lmm.config.vision_config.temporal_patch_size
        # Duplicate the single frame along the temporal-patch axis (static image).
        pixel_values = pixel_values[:, None].expand(b, temporal_patch_size, c, h, w)
        grid_t = 1
        grid_h, grid_w = h // patch_size, w // patch_size
        pixel_values = pixel_values.view(
            b, grid_t, temporal_patch_size, c,
            grid_h // spatial_merge_size, spatial_merge_size, patch_size,
            grid_w // spatial_merge_size, spatial_merge_size, patch_size)
        pixel_values = rearrange(
            pixel_values, 'b t tp c h m p w n q -> (b t h w m n) (c tp p q)')
        image_grid_thw = torch.tensor(
            [(grid_t, grid_h, grid_w)] * b).to(self._gpu_device).long()
        image_embeds = self.lmm.visual(pixel_values, grid_thw=image_grid_thw)
        image_embeds = rearrange(image_embeds, '(b l) d -> b l d', b=b)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def get_semantic_features_dynamic(self, pixel_values):
        """Per-image (variable-resolution) variant of `get_semantic_features`.

        Each image is scaled by 28/32 (patch-size ratio) instead of a fixed
        448x448 resize, then encoded individually.

        Returns:
            (list of per-image (l, d) embedding tensors wrapped at batch dim,
             image_grid_thw concatenated over all images).
        """
        def multi_apply(func, *args, **kwargs):
            # Map `func` over the args and transpose the per-call tuples.
            pfunc = partial(func, **kwargs) if kwargs else func
            map_results = map(pfunc, *args)
            return tuple(map(list, zip(*map_results)))
        pixel_values = [F.interpolate(p[None], scale_factor=28/32, mode='bilinear') for p in pixel_values]
        image_embeds, image_grid_thw = multi_apply(
            self.get_semantic_features, pixel_values, resize=False)
        image_embeds = [x[0] for x in image_embeds]
        image_grid_thw = torch.cat(image_grid_thw, dim=0)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Optional[Union[Image.Image, List[Image.Image]]] = None,
        negative_prompt: str = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 4.0,
        seed: Optional[int] = None,
        num_images_per_prompt: int = 1,
    ):
        """
        Generate or edit images.

        Args:
            prompt: Text prompt for generation/editing.
            image: Optional input image(s) for editing. If None, does text-to-image.
            negative_prompt: Negative prompt for CFG.
            height: Output image height.
            width: Output image width.
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG guidance scale.
            seed: Random seed for reproducibility.
            num_images_per_prompt: Number of images to generate per prompt.

        Returns:
            SimpleNamespace with .images attribute (list of PIL Images).
        """
        from types import SimpleNamespace
        self._load_extras()
        offload = self._cpu_offload
        gpu = self._gpu_device
        if isinstance(prompt, str):
            prompt = [prompt]
        b = len(prompt) * num_images_per_prompt
        prompt = prompt * num_images_per_prompt
        # First b entries are positive prompts; the negative copies follow.
        cfg_prompt = [negative_prompt] * b
        generator = None
        if seed is not None:
            generator = torch.Generator(device=gpu).manual_seed(seed)
        # === Stage 1: VLM + Connector ===
        if offload:
            self._offload_to(self.lmm, gpu)
            self._offload_to(self.connector_module, gpu)
        pixel_values_src = None
        cond_latents = None
        if image is not None:
            if isinstance(image, Image.Image):
                image = [image]
            ref_images = []
            for img in image:
                img = img.convert('RGB').resize((width, height))
                pv = torch.from_numpy(np.array(img)).float() / 255.0
                pv = 2 * pv - 1
                pv = rearrange(pv, 'h w c -> c h w')
                ref_images.append(pv.to(dtype=self.transformer.dtype, device=gpu))
            pixel_values_src = [[img for img in ref_images]] * b
            num_refs = [len(ref_images)] * b
            image_embeds, image_grid_thw = self.get_semantic_features_dynamic(
                [img for ref_imgs in pixel_values_src for img in ref_imgs])
            ref_lens = [len(x) for x in image_embeds]
            # Double everything for the [positive, negative] CFG halves.
            text_inputs = self.prepare_image2image_prompts(
                prompt + cfg_prompt, num_refs=num_refs * 2, ref_lens=ref_lens * 2)
            text_inputs.update(
                image_embeds=torch.cat(image_embeds * 2),
                image_grid_thw=torch.cat([image_grid_thw] * 2))
            if offload:
                self._offload_to(self.vae, gpu)
            cond_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_imgs]
                            for ref_imgs in pixel_values_src]
            cond_latents = cond_latents * 2
            if offload:
                self._offload_to(self.vae, "cpu")
        else:
            text_inputs = self.prepare_text2image_prompts(prompt + cfg_prompt)
        hidden_states = self.connector_module.meta_queries[None].expand(
            2 * b, self.num_queries, -1)
        inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
        output = self.llm(**inputs, return_dict=True, output_hidden_states=True)
        # SCB: extract multi-layer hidden states (every 6th layer from the
        # top, excluding the embedding layer) and fuse them in the connector.
        hidden_states = output.hidden_states
        num_layers = len(hidden_states) - 1
        selected_layers = list(range(num_layers - 1, 0, -6))
        selected_hiddens = [hidden_states[i] for i in selected_layers]
        merged_hidden = torch.cat(selected_hiddens, dim=-1)
        pooled_out, seq_out = self.connector_module.llm2dit(merged_hidden)
        if offload:
            # Drop LLM activations before moving modules off-GPU.
            del output, hidden_states, selected_hiddens, merged_hidden
            self._offload_to(self.lmm, "cpu")
            self._offload_to(self.connector_module, "cpu")
        # === Stage 2: DiT denoising ===
        if offload:
            self._offload_to(self.transformer, gpu)
        pipeline = _SD3Pipeline(
            transformer=self.transformer, scheduler=self.scheduler, vae=self.vae,
            text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None,
            text_encoder_3=None, tokenizer_3=None)
        samples = pipeline(
            height=height, width=width, guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            prompt_embeds=seq_out[:b], pooled_prompt_embeds=pooled_out[:b],
            negative_prompt_embeds=seq_out[b:], negative_pooled_prompt_embeds=pooled_out[b:],
            generator=generator, output_type='latent', cond_latents=cond_latents,
        ).images.to(self.transformer.dtype)
        if offload:
            self._offload_to(self.transformer, "cpu")
        # === Stage 3: VAE decode ===
        if offload:
            self._offload_to(self.vae, gpu)
        pixels = self.latents_to_pixels(samples)
        if offload:
            self._offload_to(self.vae, "cpu")
        images = []
        for i in range(pixels.shape[0]):
            img = pixels[i]
            img = rearrange(img, 'c h w -> h w c')
            # [-1, 1] float -> [0, 255] uint8.
            img = torch.clamp(127.5 * img + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
            images.append(Image.fromarray(img))
        return SimpleNamespace(images=images)