| import abc |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Tuple, Dict |
|
|
|
|
class AbstractSSL(nn.Module, abc.ABC):
    """
    Base interface for self-supervised learning (SSL) methods.

    A concrete subclass should own everything the SSL method needs during
    training: method-specific state (e.g. the key queue for MoCo, the EMA
    target network for BYOL), the loss function, and the optimizer.

    ``abc.ABC`` is mixed in so that the ``@abc.abstractmethod`` markers below
    are actually enforced. ``nn.Module`` alone does not use ``abc.ABCMeta``, so
    without the mixin a subclass missing ``__init__`` or ``forward`` would
    still instantiate and only fail later, at call time.
    """

    @abc.abstractmethod
    def __init__(
        self,
        encoder: nn.Module,
        projector: nn.Module,
    ):
        """
        Initializes the SSL method.

        Subclasses must override this method. They may call
        ``super().__init__(encoder, projector)`` to run ``nn.Module``
        initialization before registering their own submodules.

        Inputs:
            encoder: the encoder module
            projector: the projector module
        """
        # The base class deliberately does not store encoder/projector:
        # subclasses decide how (and under which names) to register them.
        super().__init__()

    @abc.abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Runs the SSL method on a batch of observations.

        Inputs:
            obs: the input observations
        Outputs:
            obs_enc: the encoded observations
            obs_proj: the projected observations
            loss: the total loss
            loss_components: the components of the total loss, keyed by name
        """
        raise NotImplementedError

    def step(self):
        """
        Updates the SSL method's internal state; call once per training step.

        Examples: stepping the optimizer, updating the MoCo key queue, or
        updating the BYOL EMA target. This method is not abstract, so
        subclasses that never call it need not define it. The base
        implementation raises NotImplementedError.
        """
        raise NotImplementedError
|
|