# tiny-flux-deep/scripts/inference_v4.py
import torch
from transformers import T5EncoderModel, T5Tokenizer, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
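# Everything below runs on CUDA in bfloat16, so a GPU with native bf16 support is assumed.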
# Load text encoders
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
t5_model = T5EncoderModel.from_pretrained("google/flan-t5-base").to("cuda", torch.bfloat16)
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda", torch.bfloat16)
# Load VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to("cuda")
# Load TinyFlux-Deep
model_py = hf_hub_download("AbstractPhil/tiny-flux-deep", "scripts/model_v4.py")
exec(open(model_py).read())  # defines TinyFluxConfig and TinyFluxDeep in this namespace
config = TinyFluxConfig(
    use_sol_prior=False,  # Disabled until trained
    use_t5_vec=False,     # Disabled until trained
)
model = TinyFluxDeep(config).to("cuda", torch.bfloat16)
weights = load_file(hf_hub_download(
    "AbstractPhil/tiny-flux-deep",
    "checkpoint_runs/v4_init/lailah_401434_v4_init.safetensors",
))
# strict=False: the checkpoint and the config (with the optional modules disabled) need not match key-for-key
model.load_state_dict(weights, strict=False)
model.eval()
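# Illustrative sanity check (optional): report roughly how many parameters were loaded.
# print(f"TinyFluxDeep parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")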
def encode_prompt(prompt):
    """Encode prompt with both T5 and CLIP."""
    # T5
    t5_tokens = t5_tokenizer(prompt, return_tensors="pt", padding="max_length",
                             max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        t5_emb = t5_model(**t5_tokens).last_hidden_state.to(torch.bfloat16)
    # CLIP
    clip_tokens = clip_tokenizer(prompt, return_tensors="pt", padding="max_length",
                                 max_length=77, truncation=True).to("cuda")
    with torch.no_grad():
        clip_out = clip_model(**clip_tokens)
        clip_pooled = clip_out.pooler_output.to(torch.bfloat16)
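    # Expected shapes (flan-t5-base and CLIP ViT-L/14 text encoders are both 768-wide):
    #   t5_emb:      [1, 77, 768]  per-token T5 hidden states
    #   clip_pooled: [1, 768]      pooled CLIP text embedding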
    return t5_emb, clip_pooled
def flux_shift(t, s=3.0):
    """Flux-style timestep shift."""
    return s * t / (1 + (s - 1) * t)
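# Worked example: with s=3, flux_shift(0.5) = 3*0.5 / (1 + 2*0.5) = 0.75, so a uniform grid is
# pushed toward t=1 and the increments shrink as sampling approaches the data end.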
@torch.inference_mode()
def generate_image(prompt, num_steps=25, cfg_scale=4.0, seed=None):
"""
Euler sampling for rectified flow.
Flow matching formulation:
x_t = (1 - t) * noise + t * data
At t=0: pure noise
At t=1: pure data
Velocity v = data - noise (constant)
Sampling: Integrate from t=0 (noise) → t=1 (data)
"""
    if seed is not None:
        torch.manual_seed(seed)
    t5_emb, clip_pooled = encode_prompt(prompt)
    t5_null, clip_null = encode_prompt("")
    # Start from pure noise (t=0)
    x = torch.randn(1, 64 * 64, 16, device="cuda", dtype=torch.bfloat16)
    img_ids = TinyFluxDeep.create_img_ids(1, 64, 64, "cuda")
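    # 64x64 latent positions x 16 channels; the FLUX VAE (8x downsampling) decodes
    # this to a 512x512 image.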
    # Timesteps: 0 → 1 with Flux shift
    t_linear = torch.linspace(0, 1, num_steps + 1, device="cuda", dtype=torch.float32)
    timesteps = flux_shift(t_linear, s=3.0)
    for i in range(num_steps):
        t_curr = timesteps[i]
        t_next = timesteps[i + 1]
        dt = t_next - t_curr  # Positive, moving toward data
        t_batch = t_curr.unsqueeze(0)
        # Predict velocity
        v_cond = model(x, t5_emb, clip_pooled, t_batch, img_ids)
        v_uncond = model(x, t5_null, clip_null, t_batch, img_ids)
        # Classifier-free guidance
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        # Euler step: x_{t+dt} = x_t + v * dt
        x = x + v * dt
    # Decode with VAE
    x = x.reshape(1, 64, 64, 16).permute(0, 3, 1, 2)  # [B, C, H, W]
    x = x / vae.config.scaling_factor
    image = vae.decode(x).sample
    # Convert to PIL
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().float().numpy()
    image = (image * 255).astype("uint8")
    return Image.fromarray(image)
# Generate
image = generate_image("a photograph of a tiger in natural habitat", seed=42)
image.save("tiger.png")