File size: 1,025 Bytes
ca5da2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config


class ConditionalEmbedder(ModelMixin, ConfigMixin):
    """
    Patchifies VAE-encoded conditions (source video or reference image)
    into the DiT hidden dimension space via a Conv3d layer.

    The Conv3d uses stride == kernel_size, so the input is split into
    non-overlapping (t, h, w) patches, each projected to a `dim`-sized vector.

    Args:
        in_dim: Number of input channels of the VAE latent (default 48).
        dim: Output channel dimension, i.e. the DiT hidden size (default 3072).
        patch_size: Patch extent along (frames, height, width). Accepts any
            sequence of three ints; default (1, 2, 2).
        zero_init: If True, the projection's weight and bias are zeroed so the
            embedder outputs all zeros initially (the condition contributes
            nothing until trained).
        ref_pad_first: Stored on the instance for callers to read; not used
            within this module's visible code — presumably controls reference-
            frame padding order downstream (verify against caller).
    """

    @register_to_config
    def __init__(
        self,
        in_dim: int = 48,
        dim: int = 3072,
        # Tuple default avoids the mutable-default-argument pitfall
        # (the previous list default was a single shared object).
        patch_size: tuple = (1, 2, 2),
        zero_init: bool = True,
        ref_pad_first: bool = False,
    ):
        super().__init__()
        # stride == kernel_size -> non-overlapping patchification.
        kernel_size = tuple(patch_size)
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=kernel_size, stride=kernel_size
        )
        self.ref_pad_first = ref_pad_first
        if zero_init:
            # Zero weight AND bias: output is exactly zero at initialization.
            nn.init.zeros_(self.patch_embedding.weight)
            nn.init.zeros_(self.patch_embedding.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project latents of shape (B, in_dim, T, H, W) to patch embeddings
        of shape (B, dim, T // pt, H // ph, W // pw)."""
        return self.patch_embedding(x)