Update main.py
Browse files
main.py
CHANGED
|
@@ -107,6 +107,7 @@ def _train(config, logger, tokenizer,
|
|
| 107 |
and utils.fsspec_exists(
|
| 108 |
config.checkpointing.resume_ckpt_path)):
|
| 109 |
ckpt_path = config.checkpointing.resume_ckpt_path
|
|
|
|
| 110 |
else:
|
| 111 |
ckpt_path = None
|
| 112 |
|
|
@@ -120,7 +121,6 @@ def _train(config, logger, tokenizer,
|
|
| 120 |
# config, tokenizer)
|
| 121 |
train_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
|
| 122 |
val_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
|
| 123 |
-
test_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
|
| 124 |
|
| 125 |
data_module = dataloader.CustomDataModule(train_dataset, val_dataset, test_dataset, tokenizer, config, batch_size=config.loader.batch_size)
|
| 126 |
train_ds = data_module.train_dataloader()
|
|
@@ -236,6 +236,32 @@ def _ppl_eval(config, tokenizer):
|
|
| 236 |
ppl = eval_utils.compute_ppl(pretrained, valid_ds)
|
| 237 |
print(f"PPL: {ppl:0.3f}")
|
| 238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
@hydra.main(version_base=None, config_path='configs',
|
| 241 |
config_name='config')
|
|
@@ -254,6 +280,8 @@ def main(config):
|
|
| 254 |
elif 'train' in config.mode:
|
| 255 |
_train(config, logger, tokenizer,
|
| 256 |
train_classifier='classifier' in config.mode)
|
|
|
|
|
|
|
| 257 |
else:
|
| 258 |
raise NotImplementedError(f"Mode {config.mode} not implemented.")
|
| 259 |
|
|
|
|
| 107 |
and utils.fsspec_exists(
|
| 108 |
config.checkpointing.resume_ckpt_path)):
|
| 109 |
ckpt_path = config.checkpointing.resume_ckpt_path
|
| 110 |
+
print(f"CKPT PATH: {ckpt_path}")
|
| 111 |
else:
|
| 112 |
ckpt_path = None
|
| 113 |
|
|
|
|
| 121 |
# config, tokenizer)
|
| 122 |
train_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
|
| 123 |
val_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
|
|
|
|
| 124 |
|
| 125 |
data_module = dataloader.CustomDataModule(train_dataset, val_dataset, test_dataset, tokenizer, config, batch_size=config.loader.batch_size)
|
| 126 |
train_ds = data_module.train_dataloader()
|
|
|
|
| 236 |
ppl = eval_utils.compute_ppl(pretrained, valid_ds)
|
| 237 |
print(f"PPL: {ppl:0.3f}")
|
| 238 |
|
| 239 |
+
def _test(config, logger, tokenizer):
|
| 240 |
+
|
| 241 |
+
test_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
|
| 242 |
+
data_module = dataloader.CustomDataModule(None, None, test_dataset=test_dataset, tokenizer=tokenizer, config=config, batch_size=config.loader.batch_size)
|
| 243 |
+
test_ds = data_module.test_dataloader()
|
| 244 |
+
|
| 245 |
+
model = diffusion.Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config, logger=False)
|
| 246 |
+
model.eval()
|
| 247 |
+
|
| 248 |
+
# Create a test trainer (without training)
|
| 249 |
+
trainer = hydra.utils.instantiate(
|
| 250 |
+
config.trainer,
|
| 251 |
+
default_root_dir=os.getcwd(),
|
| 252 |
+
# logger=wandb_logger,
|
| 253 |
+
strategy=hydra.utils.instantiate(config.strategy),
|
| 254 |
+
callbacks=[] # No need for callbacks during testing
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Test the model
|
| 258 |
+
results = trainer.test(model, test_ds)
|
| 259 |
+
|
| 260 |
+
# Log or print test results
|
| 261 |
+
print(f"Test results: {results}")
|
| 262 |
+
|
| 263 |
+
return results
|
| 264 |
+
|
| 265 |
|
| 266 |
@hydra.main(version_base=None, config_path='configs',
|
| 267 |
config_name='config')
|
|
|
|
| 280 |
elif 'train' in config.mode:
|
| 281 |
_train(config, logger, tokenizer,
|
| 282 |
train_classifier='classifier' in config.mode)
|
| 283 |
+
elif 'test' in config.mode:
|
| 284 |
+
_test(config, logger, tokenizer)
|
| 285 |
else:
|
| 286 |
raise NotImplementedError(f"Mode {config.mode} not implemented.")
|
| 287 |
|