# DeepGen-1.0-diffusers / deepgen_pipeline.py
# Hugging Face Hub tags: Text-to-Image, Diffusers, Safetensors
# Uploaded by rhli ("Upload folder using huggingface_hub"), commit 85c2ed2 (verified).
"""
DeepGen Diffusers Pipeline - Standalone pipeline for DeepGen-1.0.
This file is self-contained and does not require the DeepGen repository.
It can be used with `trust_remote_code=True` when loading from HuggingFace Hub.
Usage:
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "deepgenteam/DeepGen-1.0-diffusers",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    pipe.to("cuda")

    # Text-to-Image
    image = pipe("a racoon holding a shiny red apple", height=512, width=512).images[0]

    # Image Edit
    from PIL import Image
    image = pipe("Place this guitar on a sandy beach.",
                 image=Image.open("guitar.png"), height=512, width=512).images[0]
"""
import inspect
import math
import os
import json
import warnings
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn.init import _calculate_fan_in_and_fan_out
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from PIL import Image
from safetensors.torch import load_file
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromOriginalModelMixin,
FromSingleFileMixin,
PeftAdapterMixin,
SD3IPAdapterMixin,
SD3LoraLoaderMixin,
SD3Transformer2DLoadersMixin,
)
from diffusers.models.attention import FeedForward, JointTransformerBlock, _chunked_feed_forward
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph, randn_tensor
from transformers import (
AutoTokenizer,
CLIPTextModelWithProjection,
CLIPTokenizer,
Qwen2_5_VLForConditionalGeneration,
SiglipImageProcessor,
SiglipVisionModel,
T5EncoderModel,
T5TokenizerFast,
)
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
# Optional accelerated backends. `_flash_attention_forward` is only importable
# when flash-attn 2 is installed; it is used by ConnectorFlashAttention2.
if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
# Record whether torch_xla is present so XLA-specific code paths can be gated.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
logger = logging.get_logger(__name__)
# Per-channel RGB normalisation mean/std. These values match the OpenAI CLIP
# preprocessing constants. NOTE(review): not referenced in this part of the
# file; confirm where they are consumed.
IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)
# =============================================================================
# Connector: Config + Attention + MLP + Encoder
# =============================================================================
class ConnectorConfig(PretrainedConfig):
    """Hyper-parameters for the connector transformer encoder.

    Args:
        hidden_size: width of the token embeddings / attention projections.
        intermediate_size: hidden width of each feed-forward (MLP) block.
        num_hidden_layers: number of stacked encoder layers.
        num_attention_heads: attention heads per layer; must divide
            ``hidden_size``.
        hidden_act: activation name, looked up in ``transformers.activations.ACT2FN``.
        layer_norm_eps: epsilon used by the encoder LayerNorms.
        attention_dropout: dropout probability applied to attention weights.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Model sizes.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Activation / normalisation / regularisation knobs.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
def _trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.", stacklevel=2)
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """TensorFlow-style truncated-normal initialisation, in place, without autograd.

    The bounds ``a``/``b`` are applied to a *standard* normal sample, which is
    only afterwards scaled by ``std`` and shifted by ``mean``. So the effective
    truncation interval is ``[mean + a*std, mean + b*std]``.
    """
    standard_mean, standard_std = 0, 1.0
    with torch.no_grad():
        _trunc_normal_(tensor, standard_mean, standard_std, a, b)
        tensor.mul_(std)
        tensor.add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place so its values have variance ``scale / denom``.

    Args:
        tensor: weight tensor to fill. Fan-in / fan-out come from
            ``torch.nn.init._calculate_fan_in_and_fan_out``.
        scale: numerator of the target variance.
        mode: how ``denom`` is chosen: ``"fan_in"``, ``"fan_out"`` or
            ``"fan_avg"`` (mean of fan-in and fan-out).
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values
            above. Previously an unknown distribution silently left the
            tensor untouched.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(
            f"invalid mode {mode!r}; expected 'fan_in', 'fan_out' or 'fan_avg'")
    variance = scale / denom

    if distribution == "truncated_normal":
        # 0.8796... is the std of a standard normal truncated to [-2, 2]
        # (trunc_normal_tf_'s default bounds); dividing by it gives the
        # truncated samples the requested std, as in JAX's variance_scaling.
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-bound, bound) has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(
            f"invalid distribution {distribution!r}; expected "
            "'truncated_normal', 'normal' or 'uniform'")
def lecun_normal_(tensor):
    """LeCun-normal init in place: truncated normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, scale=1.0, distribution="truncated_normal", mode="fan_in")
def default_flax_embed_init(tensor):
    """Embedding init in place: plain (untruncated) normal with variance ``1 / fan_in``."""
    variance_scaling_(tensor, scale=1.0, distribution="normal", mode="fan_in")
class ConnectorAttention(nn.Module):
    """Eager multi-head self-attention used inside the connector encoder.

    ``hidden_states`` is projected to queries, keys and values with separate
    linear layers. Scaled dot-product attention is then computed explicitly,
    with softmax in float32 and an optional *additive* ``attention_mask``, and
    the result is projected back with ``out_proj``.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads`` and
            ``attention_dropout``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
                f"and `num_heads`: {self.num_heads}).")
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout
        # Creation order (k, v, q, out) is kept so seeded initialisation and
        # parameter ordering match existing checkpoints/behaviour.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _split_heads(self, states, batch_size, seq_len):
        """Reshape ``(batch, seq, embed_dim)`` to ``(batch, heads, seq, head_dim)``."""
        return states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Attend over ``hidden_states`` of shape ``(batch, seq, embed_dim)``.

        Args:
            hidden_states: input tokens.
            attention_mask: optional tensor *added* to the attention logits
                (e.g. ``-inf`` to block a key); must broadcast to
                ``(batch, heads, seq, seq)``.
            output_attentions: accepted for API compatibility. The eager path
                always returns the weights.

        Returns:
            ``(attn_output, attn_weights)`` with shapes ``(batch, seq, embed_dim)``
            and ``(batch, heads, seq, seq)``.
        """
        batch_size, q_len, _ = hidden_states.size()
        query_states = self._split_heads(self.q_proj(hidden_states), batch_size, q_len)
        key_states = self._split_heads(self.k_proj(hidden_states), batch_size, q_len)
        value_states = self._split_heads(self.v_proj(hidden_states), batch_size, q_len)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in float32 for numerical stability, then back to the input dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, q_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights
class ConnectorFlashAttention2(ConnectorAttention):
    """FlashAttention-2 variant of ``ConnectorAttention``.

    Delegates to transformers' ``_flash_attention_forward``, which is imported
    only when ``is_flash_attn_2_available()``, so flash-attn must be installed.
    Attention weights are never materialised: ``forward`` always returns
    ``(attn_output, None)`` regardless of ``output_attentions``.
    """

    is_causal = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Forwarded to transformers' helper as ``use_top_left_mask``; the right
        # value depends on the installed flash-attn version.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    @staticmethod
    def _autocast_gpu_dtype():
        """Return the CUDA autocast dtype on both old and new PyTorch.

        ``torch.get_autocast_gpu_dtype`` is deprecated in favour of
        ``torch.get_autocast_dtype("cuda")``. Use the new API when present.
        """
        get_autocast_dtype = getattr(torch, "get_autocast_dtype", None)
        if get_autocast_dtype is not None:
            return get_autocast_dtype("cuda")
        return torch.get_autocast_gpu_dtype()

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Attend over ``hidden_states`` ``(batch, seq, embed_dim)`` with flash-attn.

        ``attention_mask`` is passed straight through to
        ``_flash_attention_forward``. ``output_attentions`` is ignored.
        """
        batch_size, q_len, _ = hidden_states.size()
        # flash-attn consumes (batch, seq, heads, head_dim): a plain view of the
        # projections already has that layout (no transpose round trip needed).
        head_shape = (batch_size, q_len, self.num_heads, self.head_dim)
        query_states = self.q_proj(hidden_states).view(head_shape)
        key_states = self.k_proj(hidden_states).view(head_shape)
        value_states = self.v_proj(hidden_states).view(head_shape)

        dropout_rate = self.dropout if self.training else 0.0

        if query_states.dtype == torch.float32:
            # FlashAttention kernels run in fp16/bf16, so fp32 activations are
            # cast down. Prefer the active autocast dtype, then a
            # pre-quantisation dtype recorded on the config, then the
            # projection weight dtype.
            if torch.is_autocast_enabled():
                target_dtype = self._autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len,
            dropout=dropout_rate, is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask)
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
        attn_output = self.out_proj(attn_output)
        return attn_output, None
class ConnectorSdpaAttention(ConnectorAttention):
    """``ConnectorAttention`` backed by ``torch.nn.functional.scaled_dot_product_attention``.

    SDPA does not return attention weights. When ``output_attentions`` is
    requested, this falls back to the eager parent implementation; otherwise
    it returns ``(attn_output, None)``.
    """

    is_causal = False

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        if output_attentions:
            return super().forward(hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions)

        bsz, seq_len, _ = hidden_states.size()

        def to_heads(projected):
            # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
            return projected.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(proj(hidden_states)) for proj in (self.q_proj, self.k_proj, self.v_proj))

        if attention_mask is not None and q.device.type == "cuda":
            # NOTE(review): presumably a workaround for a CUDA SDPA issue with
            # non-contiguous inputs plus a custom mask (transformers applies
            # the same workaround) — confirm it is still needed.
            q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        use_causal = self.is_causal and seq_len > 1
        context = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0, is_causal=use_causal)

        merged = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.out_proj(merged), None
# Maps ``config._attn_implementation`` to the connector attention class to build.
CONNECTOR_ATTENTION_CLASSES = dict(
    eager=ConnectorAttention,
    flash_attention_2=ConnectorFlashAttention2,
    sdpa=ConnectorSdpaAttention,
)
class ConnectorMLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(act(fc1(x)))``.

    ``fc1`` maps ``hidden_size -> intermediate_size`` and ``fc2`` maps back.
    The activation is ``ACT2FN[config.hidden_act]``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Registration/creation order kept: activation, fc1, fc2.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states):
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))
def _init_connector_weights(module):
    """Initialise the parameters of one connector sub-module in place.

    Intended for ``nn.Module.apply``, which visits children before their
    parent. The attention/MLP branches therefore overwrite the generic
    ``nn.Linear`` initialisation their projections received a moment earlier.

    - ``nn.Embedding``: ``default_flax_embed_init``.
    - ``ConnectorAttention`` (and subclasses): Xavier-uniform q/k/v/out
      weights, zero biases.
    - ``ConnectorMLP``: Xavier-uniform weights, normal biases with std 1e-6.
    - other ``nn.Linear`` / ``nn.Conv2d``: ``lecun_normal_`` weight, zero bias.
    - ``nn.LayerNorm``: weight 1, bias 0. Parameters that are absent (e.g.
      ``elementwise_affine=False``) are skipped instead of raising.
    """
    if isinstance(module, nn.Embedding):
        default_flax_embed_init(module.weight)
    elif isinstance(module, ConnectorAttention):
        projections = (module.q_proj, module.k_proj, module.v_proj, module.out_proj)
        # All random weight draws first, in q/k/v/out order, so seeded runs
        # reproduce the same values.
        for proj in projections:
            nn.init.xavier_uniform_(proj.weight)
        for proj in projections:
            nn.init.zeros_(proj.bias)
    elif isinstance(module, ConnectorMLP):
        nn.init.xavier_uniform_(module.fc1.weight)
        nn.init.xavier_uniform_(module.fc2.weight)
        nn.init.normal_(module.fc1.bias, std=1e-6)
        nn.init.normal_(module.fc2.bias, std=1e-6)
    elif isinstance(module, (nn.Linear, nn.Conv2d)):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.LayerNorm):
        # A LayerNorm built without affine parameters has weight/bias = None.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)
class ConnectorEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer encoder layer.

    ``x = x + attn(ln1(x))`` followed by ``x = x + mlp(ln2(x))``. The
    attention class is chosen from ``CONNECTOR_ATTENTION_CLASSES`` by
    ``config._attn_implementation``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        attention_cls = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = ConnectorMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask, output_attentions=False):
        """Return ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` if requested."""
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
class ConnectorEncoder(nn.Module):
    """Stack of ``ConnectorEncoderLayer``s, run without an attention mask.

    All sub-modules are re-initialised with ``_init_connector_weights`` at
    construction time. When ``gradient_checkpointing`` is set and the module is
    training, each layer runs under non-reentrant activation checkpointing.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.gradient_checkpointing = False
        self.apply(_init_connector_weights)

    def forward(self, inputs_embeds):
        """Encode ``inputs_embeds`` and return the final hidden states."""
        hidden_states = inputs_embeds
        checkpointing = self.gradient_checkpointing and self.training
        for layer in self.layers:
            if checkpointing:
                outputs = torch.utils.checkpoint.checkpoint(
                    layer.__call__, hidden_states, None, False, use_reentrant=False)
            else:
                outputs = layer(hidden_states, None, output_attentions=False)
            hidden_states = outputs[0]
        return hidden_states
class DeepGenConnector(nn.Module):
    """Connector module bridging VLM hidden states to DiT conditioning.

    ``llm2dit`` applies ``projector_1`` and the connector encoder, then emits
    two conditioning signals:

    - a pooled vector: ``projector_2`` applied to the mean over dim 1;
    - a per-token sequence: ``projector_3`` applied to every token.
    """

    def __init__(self, connector_config, num_queries, llm_hidden_size,
                 projector_1_in, projector_1_out,
                 projector_2_in, projector_2_out,
                 projector_3_in, projector_3_out):
        super().__init__()
        self.connector = ConnectorEncoder(ConnectorConfig(**connector_config))
        self.projector_1 = nn.Linear(projector_1_in, projector_1_out)
        self.projector_2 = nn.Linear(projector_2_in, projector_2_out)
        self.projector_3 = nn.Linear(projector_3_in, projector_3_out)
        # Zero-initialised learnable query embeddings of shape
        # (num_queries, llm_hidden_size). NOTE(review): not used by this
        # class; presumably consumed by the pipeline — verify.
        self.meta_queries = nn.Parameter(torch.zeros(num_queries, llm_hidden_size))
        self.num_queries = num_queries

    def llm2dit(self, x):
        """Return ``(pooled_out, seq_out)`` computed from VLM hidden states ``x``."""
        encoded = self.connector(self.projector_1(x))
        return self.projector_2(encoded.mean(dim=1)), self.projector_3(encoded)
# =============================================================================
# Custom SD3 Transformer (dynamic resolution + attention mask)
# =============================================================================
class CustomJointAttnProcessor2_0:
    """Joint (sample + context) attention processor with padding-mask support.

    Mirrors diffusers' ``JointAttnProcessor2_0`` but accepts a boolean
    ``attention_mask`` of shape ``(batch, sample_len)``. True marks real sample
    tokens and False marks right-padding. Context tokens are always treated as
    valid. Every token may attend to itself, so padded query rows are never
    fully masked (which would make the softmax produce NaNs).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CustomJointAttnProcessor2_0 requires PyTorch 2.0+")

    @staticmethod
    def _pairwise_mask(token_mask):
        """Expand ``(batch, seq)`` token validity into a ``(batch, 1, seq, seq)`` SDPA mask.

        Entry ``[b, 0, i, j]`` is True when tokens ``i`` and ``j`` are both
        valid, or when ``i == j``. Non-bool 0/1 masks are cast to bool first.
        """
        token_mask = token_mask.to(torch.bool)
        pair_mask = token_mask[:, None, :] & token_mask[:, :, None]
        # In-place on a diagonal *view*: no Python-range index tensor is built
        # on the host and no extra (seq, seq) buffer is allocated.
        pair_mask.diagonal(dim1=-2, dim2=-1).fill_(True)
        return pair_mask[:, None]

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, *args, **kwargs):
        """Run joint attention over sample tokens and, if given, context tokens.

        Args:
            attn: diffusers ``Attention`` module providing the projections.
            hidden_states: sample tokens ``(batch, sample_len, dim)``.
            encoder_hidden_states: optional context tokens ``(batch, ctx_len, dim)``.
            attention_mask: optional ``(batch, sample_len)`` validity mask.

        Returns:
            ``(hidden_states, encoder_hidden_states)`` when context is given,
            otherwise only ``hidden_states``.
        """
        batch_size, sample_len = hidden_states.shape[:2]
        heads = attn.heads

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        head_dim = key.shape[-1] // heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        if encoder_hidden_states is not None:
            ctx_len = encoder_hidden_states.shape[1]
            ctx_query = split_heads(attn.add_q_proj(encoder_hidden_states))
            ctx_key = split_heads(attn.add_k_proj(encoder_hidden_states))
            ctx_value = split_heads(attn.add_v_proj(encoder_hidden_states))
            if attn.norm_added_q is not None:
                ctx_query = attn.norm_added_q(ctx_query)
            if attn.norm_added_k is not None:
                ctx_key = attn.norm_added_k(ctx_key)
            # Sample tokens first, context tokens after (split again below).
            query = torch.cat([query, ctx_query], dim=2)
            key = torch.cat([key, ctx_key], dim=2)
            value = torch.cat([value, ctx_value], dim=2)
            if attention_mask is not None:
                ctx_mask = torch.ones(batch_size, ctx_len, dtype=torch.bool, device=hidden_states.device)
                attention_mask = torch.cat([attention_mask, ctx_mask], dim=1)

        if attention_mask is not None:
            attention_mask = self._pairwise_mask(attention_mask)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, :sample_len],
                hidden_states[:, sample_len:],
            )
            if not attn.context_pre_only:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if encoder_hidden_states is not None:
            return hidden_states, encoder_hidden_states
        return hidden_states
class CustomJointTransformerBlock(JointTransformerBlock):
    """diffusers ``JointTransformerBlock`` whose attention honours a padding mask.

    ``__init__`` replaces the attention processors with
    ``CustomJointAttnProcessor2_0``. ``forward`` re-implements the parent's
    forward pass, threading an extra ``attention_mask`` into ``self.attn`` and
    (for dual-attention blocks) ``self.attn2``.

    ``forward`` returns ``(encoder_hidden_states, hidden_states)``;
    ``encoder_hidden_states`` is ``None`` when the block is ``context_pre_only``.
    NOTE(review): relies on parent attributes (norm1, norm1_context, norm2,
    norm2_context, ff, ff_context, _chunk_size, _chunk_dim, use_dual_attention,
    context_pre_only); keep in sync with the installed diffusers version.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in the mask-aware processor for both attention modules.
        self.attn.set_processor(CustomJointAttnProcessor2_0())
        if self.attn2 is not None:
            self.attn2.set_processor(CustomJointAttnProcessor2_0())
    def forward(self, hidden_states, encoder_hidden_states, temb,
                attention_mask=None, joint_attention_kwargs=None):
        joint_attention_kwargs = joint_attention_kwargs or {}
        # Conditioning norms return normalised states plus shift/scale/gate
        # terms derived from temb; dual-attention blocks get a second branch.
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(hidden_states, emb=temb)
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
        # Joint attention over sample + context tokens; the mask covers only
        # the sample tokens (the processor treats context as always valid).
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, attention_mask=attention_mask,
            encoder_hidden_states=norm_encoder_hidden_states, **joint_attention_kwargs)
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output
        if self.use_dual_attention:
            # Second, sample-only self-attention branch.
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, attention_mask=attention_mask, **joint_attention_kwargs)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2
        # Feed-forward on the sample stream, modulated by shift/scale/gate.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = hidden_states + ff_output
        if self.context_pre_only:
            # Last block: the context stream is not updated further.
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output
            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                context_ff_output = _chunked_feed_forward(self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size)
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        return encoder_hidden_states, hidden_states
class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
    """SD3-style MMDiT transformer with dynamic resolution and reference latents.

    ``forward`` accepts a batch of latents that may differ in spatial size
    (each item indexed as ``(C, H, W)``), plus optional per-sample reference
    latents (``cond_hidden_states``). Each sample is patch-embedded and its
    reference tokens are appended after its own tokens. The batch is then
    right-padded into one sequence with a boolean attention mask, which the
    ``CustomJointTransformerBlock`` layers use to ignore the padding.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["JointTransformerBlock", "CustomJointTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[int, ...] = (),
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                CustomJointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    # The last block does not update the context stream.
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=i in dual_attention_layers,
                )
                for i in range(num_layers)
            ]
        )
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        self.gradient_checkpointing = False

    @property
    def attn_processors(self):
        """Return ``{"<module path>.processor": processor}`` for every attention module."""
        processors = {}

        def fn_recursive_add_processors(name, module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        """Set one processor on every attention module, or a per-module dict.

        A dict must contain exactly one entry per attention module, keyed like
        ``attn_processors``. Entries are popped from it, so the caller's dict
        is consumed.
        """
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(f"A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}.")

        def fn_recursive_attn_processor(name, module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _embed_and_pack(self, hidden_states, cond_hidden_states):
        """Patch-embed each sample (plus its references) and right-pad into one batch.

        Returns:
            ``(packed, attention_mask, latent_sizes)``: ``packed`` is
            ``(batch, max_len, inner_dim)``; ``attention_mask`` is a boolean
            ``(batch, max_len)`` with True on real tokens; ``latent_sizes``
            holds each sample's spatial ``(H, W)`` for unpatchifying later.
        """
        latent_sizes = [latent.shape[-2:] for latent in hidden_states]
        sequences = []
        for idx in range(len(hidden_states)):
            tokens = self.pos_embed(hidden_states[idx][None])[0]
            if cond_hidden_states is not None:
                # Reference tokens follow the sample's own tokens.
                ref_tokens = [self.pos_embed(ref[None])[0] for ref in cond_hidden_states[idx]]
                tokens = torch.cat([tokens, *ref_tokens])
            sequences.append(tokens)
        max_len = max(len(seq) for seq in sequences)
        attention_mask = torch.zeros(len(sequences), max_len, dtype=torch.bool, device=self.device)
        for i, seq in enumerate(sequences):
            attention_mask[i, :len(seq)] = True
        # Right padding is pad_sequence's default; not passing ``padding_side``
        # keeps this working on torch versions that lack that argument.
        packed = pad_sequence(sequences, batch_first=True, padding_value=0.0)
        return packed, attention_mask, latent_sizes

    def _unpatchify(self, hidden_states, latent_sizes):
        """Turn packed token rows back into per-sample ``(C, H, W)`` latents.

        Only the leading ``h * w`` tokens of each row belong to the target
        latent; reference tokens and padding after them are dropped. Returns a
        stacked tensor when all samples share a shape, otherwise a list.
        """
        patch_size = self.config.patch_size
        output = []
        for tokens, (height, width) in zip(hidden_states, latent_sizes):
            h, w = height // patch_size, width // patch_size
            output.append(rearrange(tokens[:h * w], '(h w) (p q c) -> c (h p) (w q)',
                                    h=h, w=w, p=patch_size, q=patch_size))
        if output and all(sample.shape == output[0].shape for sample in output):
            return torch.stack(output)
        return output

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        cond_hidden_states=None,
        pooled_projections=None,
        timestep=None,
        block_controlnet_hidden_states=None,
        joint_attention_kwargs=None,
        return_dict=True,
        skip_layers=None,
    ):
        """Denoise a (possibly variable-resolution) batch of latents.

        Args:
            hidden_states: batch of latents; each item is indexed as ``(C, H, W)``.
            encoder_hidden_states: text/context tokens, passed through
                ``context_embedder``.
            cond_hidden_states: optional per-sample iterable of reference
                latents whose tokens are appended to that sample's tokens.
            pooled_projections: pooled conditioning for ``time_text_embed``.
            timestep: diffusion timestep(s) for ``time_text_embed``.
            block_controlnet_hidden_states: optional residuals added after
                blocks that are not ``context_pre_only``.
            joint_attention_kwargs: extra attention kwargs. ``"scale"`` is
                popped and used as the LoRA scale (PEFT backend only).
                ``"ip_adapter_image_embeds"`` is routed through
                ``self.image_proj``, which is not defined in this class;
                presumably it is added when IP-adapter weights are loaded.
                The caller's dict is copied, never mutated.
            return_dict: return ``Transformer2DModelOutput`` instead of a tuple.
            skip_layers: indices of transformer blocks to skip.

        Returns:
            ``Transformer2DModelOutput(sample=...)`` or ``(sample,)``. ``sample``
            is a stacked tensor when all outputs share a shape, else a list.
        """
        scale_requested = False
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            # Recorded before ``scale`` is popped so the warning below can fire.
            scale_requested = joint_attention_kwargs.get("scale", None) is not None
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning("Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.")

        try:
            hidden_states, attention_mask, latent_sizes = self._embed_and_pack(hidden_states, cond_hidden_states)
            temb = self.time_text_embed(timestep, pooled_projections)
            encoder_hidden_states = self.context_embedder(encoder_hidden_states)

            if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
                ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
                ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
                joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)

            for index_block, block in enumerate(self.transformer_blocks):
                is_skip = skip_layers is not None and index_block in skip_layers
                if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:
                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, joint_attention_kwargs)
                elif not is_skip:
                    encoder_hidden_states, hidden_states = block(
                        hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states,
                        temb=temb, attention_mask=attention_mask, joint_attention_kwargs=joint_attention_kwargs)
                if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                    interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
                    hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]

            hidden_states = self.norm_out(hidden_states, temb)
            hidden_states = self.proj_out(hidden_states)
            output = self._unpatchify(hidden_states, latent_sizes)
        finally:
            # Always undo the LoRA scaling, even if the forward pass raised.
            if USE_PEFT_BACKEND:
                unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
# =============================================================================
# Custom StableDiffusion3Pipeline (with cond_latents + dynamic shift)
# =============================================================================
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    """Linearly map an image token count to the dynamic-shifting parameter ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; lengths outside that range are extrapolated.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(scheduler, num_inference_steps=None, device=None, timesteps=None, sigmas=None, **kwargs):
    """Configure ``scheduler`` and return ``(timesteps, num_inference_steps)``.

    Exactly one schedule source is used:

    - explicit ``timesteps`` or explicit ``sigmas`` (at most one may be given;
      the scheduler's ``set_timesteps`` must accept that keyword), in which
      case the returned step count is ``len(scheduler.timesteps)``;
    - otherwise ``num_inference_steps``, which is returned unchanged.

    Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or the
            scheduler does not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        keyword, schedule, label = "timesteps", timesteps, "timestep"
    else:
        keyword, schedule, label = "sigmas", sigmas, "sigmas"

    if keyword not in inspect.signature(scheduler.set_timesteps).parameters:
        raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom {label} schedules.")

    scheduler.set_timesteps(**{keyword: schedule}, device=device, **kwargs)
    resolved = scheduler.timesteps
    return resolved, len(resolved)
class _SD3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
"""Internal SD3 pipeline with cond_latents support."""
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(self, transformer, scheduler, vae, text_encoder, tokenizer,
text_encoder_2, tokenizer_2, text_encoder_3, tokenizer_3,
image_encoder=None, feature_extractor=None):
super().__init__()
self.register_modules(
vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3, tokenizer=tokenizer, tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler,
image_encoder=image_encoder, feature_extractor=feature_extractor)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
self.default_sample_size = self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None else 128
self.patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
def check_inputs(self, prompt, prompt_2, prompt_3, height, width, negative_prompt=None,
negative_prompt_2=None, negative_prompt_3=None, prompt_embeds=None,
negative_prompt_embeds=None, pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None,
max_sequence_length=None):
if height % (self.vae_scale_factor * self.patch_size) != 0 or width % (self.vae_scale_factor * self.patch_size) != 0:
raise ValueError(f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size}.")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError("If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed.")
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError("If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed.")
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (batch_size, num_channels_latents, int(height) // self.vae_scale_factor, int(width) // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch size of {batch_size}.")
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@torch.no_grad()
def __call__(
    self,
    prompt=None, prompt_2=None, prompt_3=None,
    height=None, width=None, num_inference_steps=28, sigmas=None,
    guidance_scale=7.0,
    negative_prompt=None, negative_prompt_2=None, negative_prompt_3=None,
    num_images_per_prompt=1, generator=None, latents=None,
    cond_latents=None,
    prompt_embeds=None, negative_prompt_embeds=None,
    pooled_prompt_embeds=None, negative_pooled_prompt_embeds=None,
    output_type="pil", return_dict=True,
    joint_attention_kwargs=None, callback_on_step_end=None,
    callback_on_step_end_tensor_inputs=["latents"],
    max_sequence_length=256, mu=None, **kwargs,
):
    """Denoise latents with the SD3 transformer from precomputed conditioning.

    No prompt encoding happens here: ``prompt`` only determines the batch size,
    and conditioning must be passed as ``prompt_embeds`` / ``pooled_prompt_embeds``
    (plus ``negative_prompt_embeds`` / ``negative_pooled_prompt_embeds`` when
    ``guidance_scale > 1``).

    Args:
        cond_latents: Optional conditioning latents forwarded to the transformer
            as ``cond_hidden_states`` (list or tensor, one entry per sample).
        num_images_per_prompt: Samples per prompt. Per-prompt inputs whose
            leading length equals the prompt batch are repeated to match.
        output_type: ``"latent"`` returns the final latents without VAE decoding.
        Other arguments mirror the diffusers SD3 pipeline.

    Returns:
        ``StableDiffusion3PipelineOutput(images=...)``, or ``(images,)`` when
        ``return_dict`` is False.

    Raises:
        ValueError: when guidance is enabled but negative embeddings are
            missing, or when ``check_inputs`` rejects the arguments.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    self.check_inputs(prompt, prompt_2, prompt_3, height, width,
                      negative_prompt=negative_prompt, prompt_embeds=prompt_embeds,
                      negative_prompt_embeds=negative_prompt_embeds,
                      pooled_prompt_embeds=pooled_prompt_embeds,
                      negative_pooled_prompt_embeds=negative_pooled_prompt_embeds)
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False
    # CFG concatenates [negative, positive]; fail clearly instead of inside torch.cat.
    if self.do_classifier_free_guidance and (
            negative_prompt_embeds is None or negative_pooled_prompt_embeds is None):
        raise ValueError(
            "`negative_prompt_embeds` and `negative_pooled_prompt_embeds` must be "
            "provided when `guidance_scale > 1`.")

    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device

    def _repeat_per_prompt(x, prompt_batch, repeats):
        # Repeat each row `repeats` times (row0, row0, row1, row1, ...), but only
        # when `x` holds exactly one row per prompt; otherwise leave it unchanged.
        if x is None or repeats == 1 or len(x) != prompt_batch:
            return x
        if torch.is_tensor(x):
            return x.repeat_interleave(repeats, dim=0)
        return [item for item in x for _ in range(repeats)]

    # Latents are sampled for batch_size * num_images_per_prompt rows, so the
    # per-prompt conditioning has to be expanded to the same batch.
    prompt_embeds = _repeat_per_prompt(prompt_embeds, batch_size, num_images_per_prompt)
    pooled_prompt_embeds = _repeat_per_prompt(pooled_prompt_embeds, batch_size, num_images_per_prompt)
    negative_prompt_embeds = _repeat_per_prompt(negative_prompt_embeds, batch_size, num_images_per_prompt)
    negative_pooled_prompt_embeds = _repeat_per_prompt(
        negative_pooled_prompt_embeds, batch_size, num_images_per_prompt)
    cond_latents = _repeat_per_prompt(cond_latents, batch_size, num_images_per_prompt)

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt, num_channels_latents, height, width,
        prompt_embeds.dtype, device, generator, latents)

    scheduler_kwargs = {}
    if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
        # Resolution-dependent timestep shift, computed from the token count.
        _, _, h, w = latents.shape
        image_seq_len = (h // self.transformer.config.patch_size) * (w // self.transformer.config.patch_size)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.16))
        scheduler_kwargs["mu"] = mu
    elif mu is not None:
        scheduler_kwargs["mu"] = mu
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs)
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    # Duplicate the conditioning for the [negative, positive] halves unless the
    # caller already did. For a tensor, `* 2` would scale values, so concatenate.
    if cond_latents is not None and self.do_classifier_free_guidance:
        if len(cond_latents) == latents.shape[0]:
            if torch.is_tensor(cond_latents):
                cond_latents = torch.cat([cond_latents, cond_latents], dim=0)
            else:
                cond_latents = cond_latents * 2

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self._interrupt:
                continue
            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
            timestep = t.expand(latent_model_input.shape[0])
            noise_pred = self.transformer(
                hidden_states=latent_model_input, cond_hidden_states=cond_latents,
                timestep=timestep, encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False)[0]
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            if latents.dtype != latents_dtype:
                # Only restore the dtype on MPS.
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                latents = callback_outputs.pop("latents", latents)
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
            if XLA_AVAILABLE:
                xm.mark_step()

    if output_type == "latent":
        image = latents
    else:
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
    self.maybe_free_model_hooks()
    if not return_dict:
        return (image,)
    return StableDiffusion3PipelineOutput(images=image)
# =============================================================================
# DeepGen Pipeline (main entry point)
# =============================================================================
class DeepGenPipeline(DiffusionPipeline):
    """
    DeepGen 1.0 Pipeline for text-to-image generation and image editing.

    This pipeline integrates Qwen2.5-VL (VLM) + SCB Connector + SD3 DiT into a
    single interface. Standard diffusers components (transformer, vae, scheduler)
    are loaded by DiffusionPipeline; non-standard components (VLM, connector,
    tokenizer, prompt_template) are loaded automatically on first use.

    Usage:
        pipe = DiffusionPipeline.from_pretrained(
            "deepgenteam/DeepGen-1.0-diffusers",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        pipe.to("cuda")
        result = pipe("a raccoon holding an apple", height=512, width=512)
        result.images[0].save("output.png")
    """

    _optional_components = []

    def __init__(
        self,
        transformer: SD3Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
        )
        # Re-class the loaded transformer in place so it accepts cond_latents.
        self._upgrade_transformer()
        # VLM, connector, tokenizer and prompt template are filled in lazily by
        # `_load_extras` (first `__call__`, or eagerly via `from_pretrained`).
        self._extras_loaded = False
        self._cpu_offload = False
        self._gpu_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.lmm = None
        self.tokenizer = None
        self.connector_module = None
        self.prompt_template = None
        self.max_length = 1024
        self.image_token_id = None
        # Per-channel normalization applied in `get_semantic_features` before
        # pixels go to the VLM vision tower.
        self.vit_mean = torch.tensor(IMAGE_MEAN)
        self.vit_std = torch.tensor(IMAGE_STD)

    def _upgrade_transformer(self):
        """Convert standard diffusers SD3Transformer2DModel to custom version
        with cond_latents support for image editing. No weight copying needed:
        only ``__class__`` and the attention processors are swapped."""
        from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel as _OrigSD3
        if isinstance(self.transformer, _OrigSD3) and not isinstance(self.transformer, SD3Transformer2DModel):
            self.transformer.__class__ = SD3Transformer2DModel
            for block in self.transformer.transformer_blocks:
                block.__class__ = CustomJointTransformerBlock
                block.attn.set_processor(CustomJointAttnProcessor2_0())
                if block.attn2 is not None:
                    block.attn2.set_processor(CustomJointAttnProcessor2_0())

    def _resolve_pretrained_path(self):
        """Return a local directory holding the pipeline files.

        ``config._name_or_path`` is returned unchanged when it is an existing
        directory; otherwise it is treated as a Hub repo id and resolved with
        ``huggingface_hub.snapshot_download``.
        """
        path = self.config._name_or_path
        if os.path.isdir(path):
            return path
        from huggingface_hub import snapshot_download
        return snapshot_download(repo_id=path)

    def _load_extras(self, vlm_model_path=None, attn_implementation="flash_attention_2"):
        """Load non-standard components (VLM, connector, tokenizer, prompt_template).

        Returns immediately if they were already loaded.

        Args:
            vlm_model_path: Explicit path or Hub id for the Qwen2.5-VL model. When
                None, the ``vlm`` sub-folder of the pipeline is preferred, then
                the ``"vlm"`` entry of ``model_index.json`` (default
                ``"Qwen/Qwen2.5-VL-3B-Instruct"``).
            attn_implementation: Attention backend requested for the VLM. If
                loading with it fails, the VLM is reloaded with ``"sdpa"``.
        """
        if self._extras_loaded:
            return
        path = self._resolve_pretrained_path()
        # All extra modules are cast to the transformer's dtype.
        dtype = next(self.transformer.parameters()).dtype

        model_index_path = os.path.join(path, "model_index.json")
        extra_cfg = {}
        if os.path.isfile(model_index_path):
            with open(model_index_path, "r", encoding="utf-8") as f:
                extra_cfg = json.load(f)

        # Resolve VLM path: prefer local merged VLM (with LoRA baked in)
        vlm_path = vlm_model_path
        if vlm_path is None:
            local_merged = os.path.join(path, "vlm")
            if os.path.isdir(local_merged):
                vlm_path = local_merged
            else:
                vlm_path = extra_cfg.get("vlm", "Qwen/Qwen2.5-VL-3B-Instruct")
                if not os.path.isdir(vlm_path):
                    # NOTE: site-specific local mirror checked before using the Hub id.
                    local_candidate = os.path.join("/data/huggingface", vlm_path.split("/")[-1])
                    if os.path.isdir(local_candidate):
                        vlm_path = local_candidate

        print(f"Loading VLM from {vlm_path}...")
        try:
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation=attn_implementation)
        except Exception as err:
            # Presumably meant for environments without the requested backend
            # (e.g. flash-attn not installed). Retrying "sdpa" with "sdpa" is pointless.
            if attn_implementation == "sdpa":
                raise
            warnings.warn(
                f"Loading the VLM with attn_implementation={attn_implementation!r} failed "
                f"({err!r}); falling back to 'sdpa'.")
            self.lmm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                vlm_path, torch_dtype=dtype, attn_implementation="sdpa")
        self.lmm.requires_grad_(False)

        print("Loading tokenizer...")
        tokenizer_path = os.path.join(path, "tokenizer")
        if os.path.isdir(tokenizer_path):
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, trust_remote_code=True, padding_side='right')
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                vlm_path, trust_remote_code=True, padding_side='right')

        print("Loading connector...")
        connector_dir = os.path.join(path, "connector")
        with open(os.path.join(connector_dir, "config.json"), "r", encoding="utf-8") as f:
            connector_cfg = json.load(f)
        conn_cfg = connector_cfg["connector"].copy()
        conn_cfg["_attn_implementation"] = "sdpa"
        self.connector_module = DeepGenConnector(
            connector_config=conn_cfg,
            num_queries=connector_cfg["num_queries"],
            llm_hidden_size=connector_cfg["llm_hidden_size"],
            projector_1_in=connector_cfg["projector_1_in"],
            projector_1_out=connector_cfg["projector_1_out"],
            projector_2_in=connector_cfg["projector_2_in"],
            projector_2_out=connector_cfg["projector_2_out"],
            projector_3_in=connector_cfg["projector_3_in"],
            projector_3_out=connector_cfg["projector_3_out"],
        )
        connector_state = load_file(os.path.join(connector_dir, "model.safetensors"))
        self.connector_module.load_state_dict(connector_state, strict=True)
        self.connector_module = self.connector_module.to(dtype=dtype)

        prompt_template_path = os.path.join(path, "prompt_template.json")
        with open(prompt_template_path, "r", encoding="utf-8") as f:
            self.prompt_template = json.load(f)
        self.max_length = connector_cfg.get("max_length", 1024)
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(
            self.prompt_template['IMG_CONTEXT_TOKEN'])

        # With CPU offload the VLM/connector stay on CPU until `__call__` needs them.
        if not self._cpu_offload:
            device = self._gpu_device
            self.lmm = self.lmm.to(device=device)
            self.connector_module = self.connector_module.to(device=device, dtype=dtype)
        self.vit_mean = self.vit_mean.to(device=self._gpu_device)
        self.vit_std = self.vit_std.to(device=self._gpu_device)
        self._extras_loaded = True
        print("All components loaded.")

    @property
    def llm(self):
        """Language-model part of the loaded VLM (``self.lmm.language_model``)."""
        return self.lmm.language_model

    @property
    def num_queries(self):
        """Number of learnable query tokens of the connector."""
        return self.connector_module.num_queries

    def to(self, *args, **kwargs):
        """Move/cast the pipeline, including the lazily loaded extras.

        Delegates to ``DiffusionPipeline.to`` and then applies the same device
        and/or dtype to the VLM, the connector and (device only) the ViT
        normalization tensors. A given device also becomes the compute device
        used by ``__call__``.
        """
        result = super().to(*args, **kwargs)
        device = None
        dtype = None
        for arg in args:
            if isinstance(arg, torch.device):
                device = arg
            elif isinstance(arg, str):
                device = torch.device(arg)
            elif isinstance(arg, torch.dtype):
                dtype = arg
        # Normalize a keyword device (possibly a str) to torch.device as well.
        if device is None and kwargs.get("device") is not None:
            device = torch.device(kwargs["device"])
        if dtype is None:
            dtype = kwargs.get("dtype")
        if device is not None:
            self._gpu_device = device
        if self._extras_loaded:
            if device is not None:
                self.lmm = self.lmm.to(device=device)
                self.connector_module = self.connector_module.to(device=device)
                self.vit_mean = self.vit_mean.to(device=device)
                self.vit_std = self.vit_std.to(device=device)
            if dtype is not None:
                self.lmm = self.lmm.to(dtype=dtype)
                self.connector_module = self.connector_module.to(dtype=dtype)
        return result

    def enable_model_cpu_offload(self, gpu_id=None, device=None):
        """Keep components on CPU and move each to the GPU only for its stage.

        Overrides ``DiffusionPipeline.enable_model_cpu_offload`` with manual,
        stage-wise moves performed in ``__call__``. This reduces GPU memory
        usage (author's estimate: ~14GB).

        Args:
            gpu_id: CUDA device index, used when ``device`` is not given.
            device: Explicit compute device (``str`` or ``torch.device``).
        """
        self._cpu_offload = True
        if device is not None:
            self._gpu_device = torch.device(device) if isinstance(device, str) else device
        elif gpu_id is not None:
            self._gpu_device = torch.device(f"cuda:{gpu_id}")
        self.transformer = self.transformer.to("cpu")
        self.vae = self.vae.to("cpu")
        if self._extras_loaded:
            self.lmm = self.lmm.to("cpu")
            self.connector_module = self.connector_module.to("cpu")
        # The small normalization tensors always live on the compute device.
        self.vit_mean = self.vit_mean.to(self._gpu_device)
        self.vit_std = self.vit_std.to(self._gpu_device)
        torch.cuda.empty_cache()

    def _offload_to(self, module, device):
        """Move ``module`` to ``device``; release cached CUDA memory when moving to CPU."""
        module.to(device)
        if device == torch.device("cpu") or device == "cpu":
            torch.cuda.empty_cache()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the full pipeline. When called directly (not via DiffusionPipeline),
        loads all components immediately including VLM and connector.

        Extra kwargs ``vlm_model_path`` and ``attn_implementation`` are consumed
        here and forwarded to ``_load_extras``; the rest go to
        ``DiffusionPipeline.from_pretrained``.
        """
        vlm_model_path = kwargs.pop("vlm_model_path", None)
        attn_implementation = kwargs.pop("attn_implementation", "flash_attention_2")
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        pipe._load_extras(vlm_model_path=vlm_model_path,
                          attn_implementation=attn_implementation)
        return pipe

    @torch.no_grad()
    def pixels_to_latents(self, x, generator=None):
        """Encode pixels into normalized VAE latents.

        Args:
            x: Image batch for ``self.vae.encode``; ``__call__`` passes values
                in ``[-1, 1]``.
            generator: Optional ``torch.Generator`` for sampling the VAE
                posterior. ``None`` uses the global RNG.

        Returns:
            ``(z - shift_factor) * scaling_factor`` of the sampled latents.
        """
        z = self.vae.encode(x).latent_dist.sample(generator=generator)
        z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
        return z

    @torch.no_grad()
    def latents_to_pixels(self, z):
        """Invert the latent normalization of ``pixels_to_latents`` and decode with the VAE."""
        z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        x_rec = self.vae.decode(z).sample
        return x_rec

    def prepare_text2image_prompts(self, texts):
        """Wrap each text in the GENERATION then INSTRUCTION templates and tokenize (left-padded)."""
        texts = [self.prompt_template['GENERATION'].format(input=text) for text in texts]
        texts = [self.prompt_template['INSTRUCTION'].format(input=text) for text in texts]
        return self.tokenizer(
            texts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_image2image_prompts(self, texts, num_refs, ref_lens):
        """Build INSTRUCTION prompts prefixed with image placeholder tokens and tokenize them.

        Args:
            texts: One instruction per sample.
            num_refs: Number of reference images per sample.
            ref_lens: Flat list of per-image token counts, consumed in order
                across all samples; each image becomes
                IMG_START + IMG_CONTEXT * length + IMG_END.
        """
        prompts = []
        cnt = 0
        for text, num_ref in zip(texts, num_refs):
            image_tokens = ''
            for _ in range(num_ref):
                image_tokens += (self.prompt_template['IMG_START_TOKEN'] +
                                 self.prompt_template['IMG_CONTEXT_TOKEN'] * ref_lens[cnt] +
                                 self.prompt_template['IMG_END_TOKEN'])
                cnt += 1
            prompts.append(self.prompt_template['INSTRUCTION'].format(
                input=f'{image_tokens}\n{text}'))
        return self.tokenizer(
            prompts, add_special_tokens=True, return_tensors='pt',
            padding=True, padding_side='left').to(self._gpu_device)

    def prepare_forward_input(self, query_embeds, input_ids=None,
                              image_embeds=None, image_grid_thw=None,
                              attention_mask=None, past_key_values=None):
        """Build LLM forward kwargs with the query embeddings appended after the prompt.

        Args:
            query_embeds: ``(b, l, hidden)`` connector queries appended to every sequence.
            input_ids / attention_mask: Tokenized prompts (``attention_mask`` is required).
            image_embeds / image_grid_thw: Vision features scattered into the
                IMG_CONTEXT token positions, and their grids for the rope index.
            past_key_values: When given, only the query positions are fed.

        Returns:
            dict with ``inputs_embeds``, ``attention_mask``, ``position_ids``
            and ``past_key_values``.
        """
        b, l, _ = query_embeds.shape
        attention_mask = attention_mask.to(device=self._gpu_device, dtype=torch.bool)
        # Placeholder ids / attended mask for the l query slots, so that the
        # rope index also covers them.
        input_ids = torch.cat([input_ids, input_ids.new_zeros(b, l)], dim=1)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)
        position_ids, _ = self.lmm.model.get_rope_index(
            input_ids=input_ids, image_grid_thw=image_grid_thw,
            video_grid_thw=None, second_per_grid_ts=None,
            attention_mask=attention_mask)
        if past_key_values is not None:
            inputs_embeds = query_embeds
            position_ids = position_ids[..., -l:]
        else:
            input_ids = input_ids[:, :-l]
            embed_tokens = self.llm.get_input_embeddings()
            if image_embeds is None:
                inputs_embeds = embed_tokens(input_ids)
            else:
                hidden_size = self.llm.config.hidden_size
                is_image = input_ids == self.image_token_id
                # Use the embedding table's dtype, as the text-only branch does.
                inputs_embeds = torch.zeros(
                    *input_ids.shape, hidden_size,
                    device=self._gpu_device, dtype=embed_tokens.weight.dtype)
                inputs_embeds[is_image] = image_embeds.contiguous().view(-1, hidden_size)
                inputs_embeds[~is_image] = embed_tokens(input_ids[~is_image])
            inputs_embeds = torch.cat([inputs_embeds, query_embeds], dim=1)
        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                    position_ids=position_ids, past_key_values=past_key_values)

    @torch.no_grad()
    def get_semantic_features(self, pixel_values, resize=True):
        """Run the VLM vision tower on a ``(b, 3, h, w)`` batch with values in ``[-1, 1]``.

        Pixels are mapped to ``[0, 1]``, normalized with ``vit_mean``/``vit_std``,
        optionally resized to 448x448, and patchified for ``self.lmm.visual``.
        ``h`` and ``w`` must be multiples of ``patch_size * spatial_merge_size``
        for the patch view below.

        Returns:
            ``(image_embeds, image_grid_thw)`` with embeds shaped ``(b, tokens, dim)``.
        """
        pixel_values = (pixel_values + 1.0) / 2
        pixel_values = pixel_values - self.vit_mean.view(1, 3, 1, 1)
        pixel_values = pixel_values / self.vit_std.view(1, 3, 1, 1)
        if resize:
            pixel_values = F.interpolate(pixel_values, size=(448, 448), mode='bilinear')
        b, c, h, w = pixel_values.shape
        patch_size = self.lmm.config.vision_config.patch_size
        spatial_merge_size = self.lmm.config.vision_config.spatial_merge_size
        temporal_patch_size = self.lmm.config.vision_config.temporal_patch_size
        # A still image is repeated along the temporal-patch axis.
        pixel_values = pixel_values[:, None].expand(b, temporal_patch_size, c, h, w)
        grid_t = 1
        grid_h, grid_w = h // patch_size, w // patch_size
        pixel_values = pixel_values.view(
            b, grid_t, temporal_patch_size, c,
            grid_h // spatial_merge_size, spatial_merge_size, patch_size,
            grid_w // spatial_merge_size, spatial_merge_size, patch_size)
        pixel_values = rearrange(
            pixel_values, 'b t tp c h m p w n q -> (b t h w m n) (c tp p q)')
        image_grid_thw = torch.tensor(
            [(grid_t, grid_h, grid_w)] * b).to(self._gpu_device).long()
        image_embeds = self.lmm.visual(pixel_values, grid_thw=image_grid_thw)
        image_embeds = rearrange(image_embeds, '(b l) d -> b l d', b=b)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def get_semantic_features_dynamic(self, pixel_values):
        """Encode a list of ``(3, h, w)`` images one by one at (almost) native resolution.

        Each image is scaled by 28/32 before encoding. When a side is not a
        multiple of 32 that scaling does not land on a multiple of
        ``patch_size * spatial_merge_size``, and the patch view in
        ``get_semantic_features`` would fail. Such images are instead resized
        to the nearest valid multiple.

        Returns:
            ``(list of (tokens, dim) embeddings, concatenated image_grid_thw)``.
        """
        unit = (self.lmm.config.vision_config.patch_size
                * self.lmm.config.vision_config.spatial_merge_size)
        image_embeds = []
        image_grid_thw = []
        for p in pixel_values:
            h, w = p.shape[-2:]
            if h % 32 == 0 and w % 32 == 0:
                resized = F.interpolate(p[None], scale_factor=28/32, mode='bilinear')
            else:
                size = tuple(max(unit, round(s * 28 / 32 / unit) * unit) for s in (h, w))
                resized = F.interpolate(p[None], size=size, mode='bilinear')
            embeds, grid_thw = self.get_semantic_features(resized, resize=False)
            image_embeds.append(embeds[0])
            image_grid_thw.append(grid_thw)
        image_grid_thw = torch.cat(image_grid_thw, dim=0)
        return image_embeds, image_grid_thw

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        image: Optional[Union[Image.Image, List[Image.Image]]] = None,
        negative_prompt: str = "",
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 4.0,
        seed: Optional[int] = None,
        num_images_per_prompt: int = 1,
    ):
        """
        Generate or edit images.

        Args:
            prompt: Text prompt for generation/editing.
            image: Optional input image(s) for editing. If None (or an empty
                list), does text-to-image. Every image is resized to
                ``(width, height)`` and used as a reference for every sample.
            negative_prompt: Negative prompt for CFG.
            height: Output image height.
            width: Output image width.
            num_inference_steps: Number of denoising steps.
            guidance_scale: CFG guidance scale.
            seed: Random seed for reproducibility. Seeds the initial noise and,
                when editing, the VAE posterior sampling of the references.
            num_images_per_prompt: Number of images to generate per prompt.

        Returns:
            SimpleNamespace with .images attribute (list of PIL Images).
        """
        from types import SimpleNamespace
        self._load_extras()
        offload = self._cpu_offload
        gpu = self._gpu_device

        if isinstance(prompt, str):
            prompt = [prompt]
        b = len(prompt) * num_images_per_prompt
        prompt = prompt * num_images_per_prompt
        cfg_prompt = [negative_prompt] * b

        if isinstance(image, Image.Image):
            image = [image]
        elif image is not None:
            image = list(image)
        if not image:
            # An empty reference list would break feature extraction; treat it
            # as text-to-image.
            image = None

        generator = None
        if seed is not None:
            generator = torch.Generator(device=gpu).manual_seed(seed)

        # === Stage 1: VLM + Connector ===
        if offload:
            self._offload_to(self.lmm, gpu)
            self._offload_to(self.connector_module, gpu)
        cond_latents = None
        if image is not None:
            ref_images = []
            for img in image:
                img = img.convert('RGB').resize((width, height))
                pv = torch.from_numpy(np.array(img)).float() / 255.0
                pv = 2 * pv - 1
                pv = rearrange(pv, 'h w c -> c h w')
                ref_images.append(pv.to(dtype=self.transformer.dtype, device=gpu))
            # Every sample is conditioned on the same reference images.
            pixel_values_src = [[img for img in ref_images]] * b
            num_refs = [len(ref_images)] * b
            image_embeds, image_grid_thw = self.get_semantic_features_dynamic(
                [img for ref_imgs in pixel_values_src for img in ref_imgs])
            ref_lens = [len(x) for x in image_embeds]
            # Positive and negative (CFG) prompts both carry the image tokens.
            text_inputs = self.prepare_image2image_prompts(
                prompt + cfg_prompt, num_refs=num_refs * 2, ref_lens=ref_lens * 2)
            text_inputs.update(
                image_embeds=torch.cat(image_embeds * 2),
                image_grid_thw=torch.cat([image_grid_thw] * 2))
            if offload:
                self._offload_to(self.vae, gpu)
            # Pass the seeded generator so the VAE posterior sample is reproducible.
            cond_latents = [[self.pixels_to_latents(img[None], generator=generator)[0] for img in ref_imgs]
                            for ref_imgs in pixel_values_src]
            cond_latents = cond_latents * 2
            if offload:
                self._offload_to(self.vae, "cpu")
        else:
            text_inputs = self.prepare_text2image_prompts(prompt + cfg_prompt)

        hidden_states = self.connector_module.meta_queries[None].expand(
            2 * b, self.num_queries, -1)
        inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
        output = self.llm(**inputs, return_dict=True, output_hidden_states=True)
        # SCB: concatenate the hidden states at indices num_layers-1,
        # num_layers-7, ... (> 0) along the feature dimension.
        hidden_states = output.hidden_states
        num_layers = len(hidden_states) - 1
        selected_layers = list(range(num_layers - 1, 0, -6))
        selected_hiddens = [hidden_states[i] for i in selected_layers]
        merged_hidden = torch.cat(selected_hiddens, dim=-1)
        pooled_out, seq_out = self.connector_module.llm2dit(merged_hidden)
        if offload:
            del output, hidden_states, selected_hiddens, merged_hidden
            self._offload_to(self.lmm, "cpu")
            self._offload_to(self.connector_module, "cpu")

        # === Stage 2: DiT denoising ===
        if offload:
            self._offload_to(self.transformer, gpu)
        pipeline = _SD3Pipeline(
            transformer=self.transformer, scheduler=self.scheduler,
            vae=self.vae, text_encoder=None, tokenizer=None,
            text_encoder_2=None, tokenizer_2=None,
            text_encoder_3=None, tokenizer_3=None)
        # The first b rows are the positive prompts, the last b the negatives.
        samples = pipeline(
            height=height, width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            prompt_embeds=seq_out[:b],
            pooled_prompt_embeds=pooled_out[:b],
            negative_prompt_embeds=seq_out[b:],
            negative_pooled_prompt_embeds=pooled_out[b:],
            generator=generator,
            output_type='latent',
            cond_latents=cond_latents,
        ).images.to(self.transformer.dtype)
        if offload:
            self._offload_to(self.transformer, "cpu")

        # === Stage 3: VAE decode ===
        if offload:
            self._offload_to(self.vae, gpu)
        pixels = self.latents_to_pixels(samples)
        if offload:
            self._offload_to(self.vae, "cpu")

        # Map [-1, 1] -> [0, 255]. The "+128" (rather than +127.5) makes the
        # uint8 truncation round to nearest. Compute in float32: reduced
        # precision such as bfloat16 cannot represent the half-steps in [128, 256).
        pixels = torch.clamp(127.5 * pixels.float() + 128.0, 0, 255)
        pixels = rearrange(pixels, 'b c h w -> b h w c').to("cpu", dtype=torch.uint8).numpy()
        images = [Image.fromarray(arr) for arr in pixels]
        return SimpleNamespace(images=images)