| """ |
| VAE Wrappers — corrected for actual TAESD and SD-VAE APIs. |
| |
| TAESD (AutoencoderTiny): |
| - encode(x) returns AutoencoderTinyOutput with .latents (no sampling) |
| - scaling_factor = 1.0 (no scaling needed) |
| - decode(z) returns DecoderOutput with .sample |
| |
| SD-VAE (AutoencoderKL): |
- encode(x) returns AutoencoderKLOutput with .latent_dist
| - scaling_factor = 0.18215 |
| - decode(z) returns DecoderOutput with .sample |
| """ |
|
|
| import torch |
|
|
|
|
class TAESDWrapper:
    """
    Thin convenience layer around the Tiny AutoEncoder for Stable
    Diffusion (TAESD).

    Unlike the full SD VAE, TAESD's encoder is deterministic: its output
    exposes ``.latents`` directly (there is no distribution to sample),
    and its scaling_factor is 1.0, so latents need no rescaling.

    Checkpoint: madebyollin/taesd (~2.5M params, 9.8MB).
    """

    @staticmethod
    def load(device='cpu'):
        """Load the TAESD checkpoint from HuggingFace onto *device* in eval mode."""
        from diffusers import AutoencoderTiny

        taesd = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        )
        # Both .to() and .eval() return the module itself, so chaining is safe.
        return taesd.to(device).eval()

    @staticmethod
    def encode(vae, x):
        """
        Map images to latents (gradient-free).

        Args:
            vae: loaded AutoencoderTiny model.
            x: [B, 3, H, W] images in [-1, 1].
        Returns:
            [B, 4, H/8, W/8] latents (deterministic, no scaling applied).
        """
        with torch.no_grad():
            return vae.encode(x).latents

    @staticmethod
    def decode(vae, z):
        """
        Map latents back to images (gradient-free).

        Args:
            vae: loaded AutoencoderTiny model.
            z: [B, 4, H/8, W/8] latents.
        Returns:
            [B, 3, H, W] images in [-1, 1].
        """
        with torch.no_grad():
            return vae.decode(z).sample

    @staticmethod
    def get_latent_shape(image_size):
        """Side length of the latent grid (the VAE compresses spatially by 8x)."""
        return image_size // 8
|
|
|
|
class SDVAEWrapper:
    """
    Thin convenience layer around Stability AI's fine-tuned VAE
    (sd-vae-ft-mse, ~84M params).

    Unlike TAESD, this encoder is probabilistic: ``encode`` returns a
    ``latent_dist`` that must be sampled, and latents are scaled by
    ``config.scaling_factor`` (0.18215 for this checkpoint).
    """

    @staticmethod
    def load(device='cpu'):
        """Load the SD VAE checkpoint from HuggingFace onto *device* in eval mode."""
        from diffusers import AutoencoderKL

        sd_vae = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        # Both .to() and .eval() return the module itself, so chaining is safe.
        return sd_vae.to(device).eval()

    @staticmethod
    def encode(vae, x):
        """Sample latents from the encoder posterior and apply the scaling factor."""
        with torch.no_grad():
            latents = vae.encode(x).latent_dist.sample()
            return latents * vae.config.scaling_factor

    @staticmethod
    def decode(vae, z):
        """Undo the latent scaling, then decode back to image space."""
        with torch.no_grad():
            return vae.decode(z / vae.config.scaling_factor).sample
|
|