"""
VAE Wrappers — corrected for actual TAESD and SD-VAE APIs.
TAESD (AutoencoderTiny):
- encode(x) returns AutoencoderTinyOutput with .latents (no sampling)
- scaling_factor = 1.0 (no scaling needed)
- decode(z) returns DecoderOutput with .sample
SD-VAE (AutoencoderKL):
- encode(x) returns AutoEncoderKLOutput with .latent_dist
- scaling_factor = 0.18215
- decode(z) returns DecoderOutput with .sample
"""
import torch
class TAESDWrapper:
    """
    Thin helper around the Tiny AutoEncoder for Stable Diffusion (TAESD).

    TAESD's encoder is deterministic: encode() yields `.latents` directly
    (there is no latent distribution to sample from) and its scaling
    factor is 1.0, so latents pass through unscaled.
    Checkpoint: madebyollin/taesd (~2.5M params, 9.8MB).
    """
    @staticmethod
    def load(device='cpu'):
        """Fetch the TAESD weights from HuggingFace and prepare them for inference."""
        from diffusers import AutoencoderTiny
        tiny_vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        ).to(device)
        tiny_vae.eval()
        return tiny_vae
    @staticmethod
    def encode(vae, x):
        """
        Map images to latents.
        Args:
        vae: AutoencoderTiny model
        x: [B, 3, H, W] images in [-1, 1]
        Returns:
        z: [B, 4, H/8, W/8] latents
        """
        # Deterministic encoder: the output carries .latents, not .latent_dist.
        with torch.no_grad():
            return vae.encode(x).latents
    @staticmethod
    def decode(vae, z):
        """
        Map latents back to images.
        Args:
        vae: AutoencoderTiny model
        z: [B, 4, H/8, W/8] latents
        Returns:
        x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            return vae.decode(z).sample
    @staticmethod
    def get_latent_shape(image_size):
        """Spatial size of the latent grid (the VAE downsamples 8x)."""
        return image_size // 8
class SDVAEWrapper:
    """
    Wrapper for Stability AI VAE (sd-vae-ft-mse).
    Key: Uses .latent_dist.sample() and scaling_factor=0.18215.
    Model: stabilityai/sd-vae-ft-mse (~84M params)
    """
    @staticmethod
    def load(device='cpu'):
        """Load SD VAE model from HuggingFace and put it in eval mode."""
        from diffusers import AutoencoderKL
        model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model
    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent (with scaling).
        Args:
        vae: AutoencoderKL model
        x: [B, 3, H, W] images in [-1, 1]
        Returns:
        z: [B, 4, H/8, W/8] scaled latents
        Note: .sample() draws from the latent distribution, so repeated
        calls on the same input produce different latents.
        """
        with torch.no_grad():
            posterior = vae.encode(x).latent_dist
            z = posterior.sample()
            # Scale into the range diffusion models are trained on.
            z = z * vae.config.scaling_factor
        return z
    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image (with unscaling).
        Args:
        vae: AutoencoderKL model
        z: [B, 4, H/8, W/8] scaled latents
        Returns:
        x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            # Undo the encode-time scaling before the decoder sees the latent.
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x
    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size (8x compression), mirroring TAESDWrapper."""
        return image_size // 8