| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import Optional |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| |
|
class ClsToken(nn.Module):
    """Prepends learnable CLS (and optional register) tokens to a token sequence.

    Args:
        ndim: Embedding dimension of each token.
        num_tokens: Number of CLS tokens to prepend.
        enabled: If False, no parameters are created and ``forward`` is a no-op.
        register_multiple: If given (and ``num_registers`` is falsy), add register
            tokens so that ``num_tokens + num_registers`` is a multiple of this value.
        num_registers: Explicit number of register tokens; takes precedence over
            ``register_multiple``.
    """

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: Optional[int] = None,
                 num_registers: Optional[int] = None,
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if num_registers:
                self.num_registers = num_registers
            elif register_multiple:
                # Pad so that (num_tokens + num_registers) is a multiple of
                # register_multiple. FIX: the previous expression
                # `register_multiple - (num_tokens % register_multiple)` added a
                # full extra multiple (instead of 0 registers) whenever
                # num_tokens was already an exact multiple of register_multiple.
                self.num_registers = -num_tokens % register_multiple

            # ViT-style small initialization, scaled by 1/sqrt(ndim).
            scale = ndim ** -0.5
            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
        else:
            self.token = None

        # Total number of tokens this module prepends to the input sequence.
        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        """Drop the token parameter and make ``forward`` the identity."""
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend the learned tokens to ``x``.

        Args:
            x: input tokens — assumes shape (batch, seq, ndim); not validated here.

        Returns:
            ``x`` unchanged when disabled, otherwise a tensor of shape
            (batch, num_patches + seq, ndim) with the tokens at the front.
        """
        if self.token is None:
            return x

        # Broadcast the shared token block across the batch dimension.
        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([
            token,
            x,
        ], dim=1)

        return x

    def no_weight_decay(self):
        """Parameter names that optimizers should exclude from weight decay."""
        return [
            'token',
        ]
| |
|