| import torch |
| import torch.nn as nn |
|
|
| from model.encoder import Encoder |
| from model.decoder import Decoder |
|
|
| from model.utils import weight_init |
|
|
|
|
class Trainer(nn.Module):
    """Encoder/decoder wrapper mapping an input pair ``(x, y)`` to a prediction.

    ``x`` and ``y`` are passed together to ``Encoder``; its two outputs are fed
    to ``Decoder``, whose result is returned unchanged. The decoder's weights
    are initialised with ``weight_init`` right after construction.

    Args:
        model_type: Backbone size, ``'tiny'`` or ``'small'``. It is forwarded
            to ``Encoder`` and selects the width of the last entry of the
            decoder's ``in_dim`` list (192 for ``'tiny'``, 384 for ``'small'``).

    Raises:
        ValueError: If ``model_type`` is not one of the supported sizes.
    """

    # Width of the last decoder input level for each supported model_type.
    _EMBED_DIMS = {'tiny': 192, 'small': 384}

    def __init__(self, model_type='small'):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``, which would leave ``embed_dim`` unbound and crash
        # later with an unrelated UnboundLocalError.
        if model_type not in self._EMBED_DIMS:
            raise ValueError(
                f"Trainer: check the vit model type "
                f"(got {model_type!r}, expected one of {sorted(self._EMBED_DIMS)})"
            )
        embed_dim = self._EMBED_DIMS[model_type]

        self.encoder = Encoder(model_type)

        # NOTE(review): assumes the encoder yields four feature levels with
        # channel widths 64, 128, 256 and embed_dim — TODO confirm against
        # Encoder's implementation.
        self.decoder = Decoder(in_dim=[64, 128, 256, embed_dim])
        weight_init(self.decoder)

    def forward(self, x, y):
        """Encode the pair ``(x, y)`` and decode the resulting features.

        Args:
            x: First input, passed to ``Encoder`` together with ``y``.
            y: Second input, passed to ``Encoder`` together with ``x``.

        Returns:
            Whatever ``Decoder`` returns for the two encoder outputs
            (shape/type not visible here — TODO confirm in Decoder).
        """
        fx, fy = self.encoder(x, y)
        pred = self.decoder(fx, fy)
        return pred
| |