| import torch |
| import torch.nn as nn |
| from timm.layers import PatchEmbed as TimmPatchEmbed |
|
|
|
|
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding.

    Thin wrapper around timm's ``PatchEmbed`` layer: the constructor arguments
    are recorded on the instance and forwarded to the timm layer, which
    performs the actual embedding.

    Args:
        img_size (tuple[int, int]): Input image size (H, W).
        patch_size (tuple[int, int]): Patch size (H, W).
        in_chans (int): Number of input channels.
        embed_dim (int): Embedding dimension.

    Attributes:
        img_size: The ``img_size`` argument, stored as given.
        patch_size: The ``patch_size`` argument, stored as given.
        in_chans: The ``in_chans`` argument, stored as given.
        embed_dim: The ``embed_dim`` argument, stored as given.
        patch_embed: The wrapped timm ``PatchEmbed`` layer.
        num_patches: Copied from ``patch_embed.num_patches`` at construction
            time; it is not recomputed in ``forward``.
    """

    def __init__(
        self,
        img_size: tuple[int, int] = (128, 256),
        patch_size: tuple[int, int] = (16, 16),
        in_chans: int = 1,
        embed_dim: int = 768,
    ):
        super().__init__()
        # Keep the configuration on the instance so callers can inspect it.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # NOTE(review): strict_img_size=False presumably lets timm accept
        # inputs whose spatial size differs from img_size -- confirm against
        # the installed timm version.
        self.patch_embed = TimmPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten=True,
            bias=True,
            strict_img_size=False,
        )
        # Snapshot taken once from the timm layer as configured above; it is
        # never updated afterwards, so callers feeding other input sizes
        # should not rely on it matching the actual token count (TODO confirm).
        self.num_patches = self.patch_embed.num_patches

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Delegates directly to the wrapped timm layer; nothing is done to the
        input before the call or to the result after it.

        Args:
            x (torch.Tensor): Input tensor [B, C, H, W].

        Returns:
            torch.Tensor: Patch embeddings [B, N, D].
        """
        return self.patch_embed(x)
|
|