File size: 3,037 Bytes
b9e8cb3
3798d56
b9e8cb3
3798d56
 
 
 
b9e8cb3
3798d56
 
 
 
b9e8cb3
 
 
 
 
 
 
 
 
3798d56
 
b9e8cb3
3798d56
b9e8cb3
 
 
 
3798d56
b9e8cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
3798d56
b9e8cb3
 
3798d56
b9e8cb3
 
3798d56
 
b9e8cb3
 
 
 
 
 
 
3798d56
 
b9e8cb3
 
 
 
 
 
3798d56
 
 
 
 
b9e8cb3
 
 
 
 
 
3798d56
b9e8cb3
3798d56
b9e8cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3798d56
b9e8cb3
 
 
 
 
 
 
 
3798d56
b9e8cb3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
VAE Wrappers — corrected for actual TAESD and SD-VAE APIs.

TAESD (AutoencoderTiny):
  - encode(x) returns AutoencoderTinyOutput with .latents (no sampling)
  - scaling_factor = 1.0 (no scaling needed)
  - decode(z) returns DecoderOutput with .sample

SD-VAE (AutoencoderKL):
  - encode(x) returns AutoencoderKLOutput with .latent_dist
  - scaling_factor = 0.18215
  - decode(z) returns DecoderOutput with .sample
"""

import torch


class TAESDWrapper:
    """
    Thin static facade over the Tiny AutoEncoder for Stable Diffusion (TAESD).

    TAESD's encoder is deterministic: encode() exposes `.latents` directly
    (there is no latent distribution to sample from), and its scaling factor
    is 1.0, so latents pass through unscaled in both directions.

    Model: madebyollin/taesd (~2.5M params, 9.8MB)
    """

    @staticmethod
    def load(device='cpu'):
        """Fetch the pretrained AutoencoderTiny from HuggingFace, move it to
        *device*, and put it in eval mode."""
        from diffusers import AutoencoderTiny
        vae = AutoencoderTiny.from_pretrained(
            "madebyollin/taesd",
            torch_dtype=torch.float32,
        ).to(device)
        vae.eval()
        return vae

    @staticmethod
    def encode(vae, x):
        """
        Map images to latents.

        Args:
            vae: AutoencoderTiny model
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] latents
        """
        # TAESD output carries .latents directly — no .latent_dist to sample.
        with torch.no_grad():
            return vae.encode(x).latents

    @staticmethod
    def decode(vae, z):
        """
        Map latents back to images.

        Args:
            vae: AutoencoderTiny model
            z: [B, 4, H/8, W/8] latents
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            return vae.decode(z).sample

    @staticmethod
    def get_latent_shape(image_size):
        """Spatial size of the latent grid (the VAE compresses 8x per axis)."""
        return image_size // 8


class SDVAEWrapper:
    """
    Wrapper for Stability AI VAE (sd-vae-ft-mse).

    Key: the KL VAE returns a latent *distribution* — encode() samples from
    .latent_dist and applies vae.config.scaling_factor (0.18215 for this
    model); decode() divides the scaling back out first.

    Model: stabilityai/sd-vae-ft-mse (~84M params)
    """

    @staticmethod
    def load(device='cpu'):
        """Load the pretrained AutoencoderKL, move it to *device*, eval mode."""
        from diffusers import AutoencoderKL
        model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-mse",
            torch_dtype=torch.float32,
        )
        model = model.to(device)
        model.eval()
        return model

    @staticmethod
    def encode(vae, x):
        """
        Encode image to latent (posterior sample, then scaled).

        Args:
            vae: AutoencoderKL model
            x: [B, 3, H, W] images in [-1, 1]
        Returns:
            z: [B, 4, H/8, W/8] latents, multiplied by vae.config.scaling_factor
        """
        with torch.no_grad():
            # KL VAE yields a distribution; sample (stochastic) then scale.
            posterior = vae.encode(x).latent_dist
            z = posterior.sample()
            z = z * vae.config.scaling_factor
        return z

    @staticmethod
    def decode(vae, z):
        """
        Decode latent to image (unscales first).

        Args:
            vae: AutoencoderKL model
            z: [B, 4, H/8, W/8] scaled latents (as produced by encode)
        Returns:
            x: [B, 3, H, W] images in [-1, 1]
        """
        with torch.no_grad():
            # Undo the encode-side scaling before running the decoder.
            z = z / vae.config.scaling_factor
            x = vae.decode(z).sample
        return x

    @staticmethod
    def get_latent_shape(image_size):
        """Get latent spatial size (8x compression) — mirrors TAESDWrapper."""
        return image_size // 8