| from accelerate import Accelerator | |
| from datasets.core import TrajectoryDataset | |
class Workspace:
    """Hold the configuration, working directory, models and dataset of a run.

    The models and the dataset are attached after construction through
    ``set_models`` and ``set_dataset``; until then the corresponding
    attributes are ``None``.
    """

    # Declared at class level so these attributes always exist (as None),
    # even on an instance whose setters have not run yet. Without this,
    # reading ``encoder``/``projector`` before ``set_models`` raised
    # AttributeError. The annotation is a string so it is never evaluated.
    encoder = None
    projector = None
    dataset: "TrajectoryDataset | None" = None

    def __init__(self, cfg, work_dir):
        """Store the run settings and create the ``Accelerator``.

        Args:
            cfg: Run configuration; stored as-is on ``self.cfg``.
            work_dir: Working directory for the run; stored as-is on
                ``self.work_dir`` (not validated or created here).
        """
        self.cfg = cfg
        self.work_dir = work_dir
        self.accelerator = Accelerator()
        # Put every optional component on the instance explicitly, so
        # ``vars(self)`` lists them just as the original did for ``dataset``.
        self.encoder = None
        self.projector = None
        self.dataset = None

    def set_models(self, encoder, projector):
        """Attach the encoder and projector models to this workspace.

        Args:
            encoder: Encoder model; stored on ``self.encoder``.
            projector: Projector model; stored on ``self.projector``.
        """
        self.encoder = encoder
        self.projector = projector

    def set_dataset(self, dataset):
        """Attach the dataset to this workspace.

        Args:
            dataset: Dataset object (annotated as ``TrajectoryDataset``);
                stored on ``self.dataset`` without validation.
        """
        self.dataset = dataset

    def run_offline_eval(self):
        """Return a dict of offline-evaluation metrics.

        NOTE(review): placeholder implementation. It does not use the
        models or the dataset and always reports ``{"loss": 0}``.

        Returns:
            dict: ``{"loss": 0}``.
        """
        return {"loss": 0}