| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from torch import nn, Tensor |
|
|
|
|
| def _bf16_u16(x: Tensor) -> Tensor: |
| |
| return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF |
|
|
|
|
class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT; invalid sigma => OOB index error (no silent wrong)."""

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        dev = next(base.parameters()).device

        # Quantize the schedule to bf16 and require every level to keep a
        # distinct bit pattern — otherwise two sigmas would share a cache row.
        levels = torch.tensor(sigmas, device=dev, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        # Run `base` once over the whole schedule and freeze the results to bf16.
        with torch.no_grad():
            emb = base(levels[:, None]).squeeze(1)
            table = emb.to(torch.bfloat16).contiguous()

        # Full 2^16 lookup table: bf16 bit pattern -> row of `table`; -1 marks
        # patterns that never appear in the schedule.
        lut = torch.full((65536,), -1, device=dev, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=dev, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # One-past-the-end sentinel: unseen sigmas index out of range and raise
        # instead of silently producing a wrong embedding.
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=dev, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        row = self.lut[_bf16_u16(sigma)]
        row = torch.where(row >= 0, row, self.oob)
        return self.table[row.to(torch.int64)]
|
|
|
|
class CachedCondHead(nn.Module):
    """bf16 cond -> cached (s0,b0,g0,s1,b1,g1); invalid cond => OOB index error (no silent wrong)."""

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        n_sigmas, emb_dim = table.shape

        # Evaluate the head once per cached sigma embedding; freeze all outputs
        # to bf16 and stack them into one (num_outputs, S, ...) cache tensor.
        with torch.no_grad():
            outputs = [out.squeeze(1) for out in base(table[:, None, :])]
            cache = torch.stack(outputs, 0).to(torch.bfloat16).contiguous()

        # Find one embedding dimension whose bf16 bit patterns are distinct
        # across all sigmas; that coordinate acts as a perfect-hash key from a
        # cond vector back to its cache row.
        chosen = None
        for d in range(min(emb_dim, max_key_dims)):
            candidate = _bf16_u16(table[:, d])
            if torch.unique(candidate).numel() == n_sigmas:
                chosen = d
                key_bits = candidate
                break
        if chosen is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping; increase max_key_dims"
            )

        # 2^16 lookup table: key bit pattern -> cache row; -1 for patterns the
        # schedule never produces.
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(n_sigmas, device=table.device, dtype=torch.int32)

        self.key_dim = int(chosen)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # One-past-the-end sentinel: unknown conds index out of range and raise
        # rather than silently returning wrong values.
        self.register_buffer(
            "oob",
            torch.tensor(n_sigmas, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        row = self.lut[_bf16_u16(cond[..., self.key_dim])]
        row = torch.where(row >= 0, row, self.oob)
        picked = self.cache[:, row.to(torch.int64)]
        return tuple(picked.unbind(0))
|
|