# LiquidFlow-Gen — liquid_flow/vae_wrapper.py
"""
VAE Wrappers — corrected for actual TAESD and SD-VAE APIs.
TAESD (AutoencoderTiny):
- encode(x) returns AutoencoderTinyOutput with .latents (no sampling)
- scaling_factor = 1.0 (no scaling needed)
- decode(z) returns DecoderOutput with .sample
SD-VAE (AutoencoderKL):
- encode(x) returns AutoEncoderKLOutput with .latent_dist
- scaling_factor = 0.18215
- decode(z) returns DecoderOutput with .sample
"""
import torch
class TAESDWrapper:
    """
    Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD).
    Key: TAESD uses .latents directly (deterministic encoder, no sampling).
    scaling_factor = 1.0, so no scaling needed.
    Model: madebyollin/taesd (~2.5M params, 9.8MB)
    """

    @staticmethod
    def load(device='cpu'):
        """Download TAESD from the HuggingFace Hub and return it in eval mode on *device*."""
        from diffusers import AutoencoderTiny

        # .to() and .eval() both return the module, so the calls chain.
        return (
            AutoencoderTiny.from_pretrained(
                "madebyollin/taesd",
                torch_dtype=torch.float32,
            )
            .to(device)
            .eval()
        )

    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent.
        Args:
            vae: AutoencoderTiny model
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] latents
        """
        with torch.no_grad():
            # Deterministic encoder: the output carries .latents directly,
            # not a .latent_dist to sample from (unlike AutoencoderKL).
            encoded = vae.encode(x)
        return encoded.latents

    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image.
        Args:
            vae: AutoencoderTiny model
            z: [B, 4, H/8, W/8] latents
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            decoded = vae.decode(z)
        return decoded.sample

    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size (8x compression)."""
        return image_size // 8
class SDVAEWrapper:
    """
    Wrapper for Stability AI VAE (sd-vae-ft-mse).
    Key: Uses .latent_dist.sample() and scaling_factor=0.18215.
    Model: stabilityai/sd-vae-ft-mse (~84M params)
    """
    @staticmethod
    def load(device='cpu'):
        """Load SD VAE model from HuggingFace and return it in eval mode on *device*."""
        from diffusers import AutoencoderKL
        model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent (with scaling).
        Args:
            vae: AutoencoderKL model
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] scaled latents
        Note:
            Stochastic: samples from the posterior, so repeated calls on the
            same input yield (slightly) different latents.
        """
        with torch.no_grad():
            posterior = vae.encode(x).latent_dist
            z = posterior.sample()
            # SD convention: multiply by scaling_factor (0.18215) so latents
            # have roughly unit variance for the diffusion model.
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image (with unscaling).
        Args:
            vae: AutoencoderKL model
            z: [B, 4, H/8, W/8] scaled latents
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            # Undo the encode-time scaling before running the decoder.
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x

    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size (8x compression) — mirrors TAESDWrapper's API."""
        return image_size // 8