""" VAE Wrappers — corrected for actual TAESD and SD-VAE APIs. TAESD (AutoencoderTiny): - encode(x) returns AutoencoderTinyOutput with .latents (no sampling) - scaling_factor = 1.0 (no scaling needed) - decode(z) returns DecoderOutput with .sample SD-VAE (AutoencoderKL): - encode(x) returns AutoEncoderKLOutput with .latent_dist - scaling_factor = 0.18215 - decode(z) returns DecoderOutput with .sample """ import torch class TAESDWrapper: """ Wrapper for Tiny AutoEncoder for Stable Diffusion (TAESD). Key: TAESD uses .latents directly (deterministic encoder, no sampling). scaling_factor = 1.0, so no scaling needed. Model: madebyollin/taesd (~2.5M params, 9.8MB) """ @staticmethod def load(device='cpu'): """Load TAESD model from HuggingFace.""" from diffusers import AutoencoderTiny model = AutoencoderTiny.from_pretrained( "madebyollin/taesd", torch_dtype=torch.float32, ) model = model.to(device) model.eval() return model @staticmethod def encode(vae, x): """ Encode image to latent. Args: vae: AutoencoderTiny model x: [B, 3, H, W] images in [-1, 1] Returns: z: [B, 4, H/8, W/8] latents """ with torch.no_grad(): # TAESD returns .latents directly (no latent_dist) z = vae.encode(x).latents return z @staticmethod def decode(vae, z): """ Decode latent to image. Args: vae: AutoencoderTiny model z: [B, 4, H/8, W/8] latents Returns: x: [B, 3, H, W] images in [-1, 1] """ with torch.no_grad(): x = vae.decode(z).sample return x @staticmethod def get_latent_shape(image_size): """Get latent spatial size (8x compression).""" return image_size // 8 class SDVAEWrapper: """ Wrapper for Stability AI VAE (sd-vae-ft-mse). Key: Uses .latent_dist.sample() and scaling_factor=0.18215. Model: stabilityai/sd-vae-ft-mse (~84M params) """ @staticmethod def load(device='cpu'): """Load SD VAE model.""" from diffusers import AutoencoderKL model = AutoencoderKL.from_pretrained( "stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32, ) model = model.to(device) model.eval() return model @staticmethod def encode(vae, x): """Encode image to latent (with scaling).""" with torch.no_grad(): posterior = vae.encode(x).latent_dist z = posterior.sample() z = z * vae.config.scaling_factor return z @staticmethod def decode(vae, z): """Decode latent to image (with unscaling).""" with torch.no_grad(): z = z / vae.config.scaling_factor x = vae.decode(z).sample return x