AlienChen commited on
Commit
f96153b
·
verified ·
1 Parent(s): f095528

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +29 -1
main.py CHANGED
@@ -107,6 +107,7 @@ def _train(config, logger, tokenizer,
107
  and utils.fsspec_exists(
108
  config.checkpointing.resume_ckpt_path)):
109
  ckpt_path = config.checkpointing.resume_ckpt_path
 
110
  else:
111
  ckpt_path = None
112
 
@@ -120,7 +121,6 @@ def _train(config, logger, tokenizer,
120
  # config, tokenizer)
121
  train_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
122
  val_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
123
- test_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
124
 
125
  data_module = dataloader.CustomDataModule(train_dataset, val_dataset, test_dataset, tokenizer, config, batch_size=config.loader.batch_size)
126
  train_ds = data_module.train_dataloader()
@@ -236,6 +236,32 @@ def _ppl_eval(config, tokenizer):
236
  ppl = eval_utils.compute_ppl(pretrained, valid_ds)
237
  print(f"PPL: {ppl:0.3f}")
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  @hydra.main(version_base=None, config_path='configs',
241
  config_name='config')
@@ -254,6 +280,8 @@ def main(config):
254
  elif 'train' in config.mode:
255
  _train(config, logger, tokenizer,
256
  train_classifier='classifier' in config.mode)
 
 
257
  else:
258
  raise NotImplementedError(f"Mode {config.mode} not implemented.")
259
 
 
107
  and utils.fsspec_exists(
108
  config.checkpointing.resume_ckpt_path)):
109
  ckpt_path = config.checkpointing.resume_ckpt_path
110
+ print(f"CKPT PATH: {ckpt_path}")
111
  else:
112
  ckpt_path = None
113
 
 
121
  # config, tokenizer)
122
  train_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/train')
123
  val_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/val')
 
124
 
125
  data_module = dataloader.CustomDataModule(train_dataset, val_dataset, test_dataset, tokenizer, config, batch_size=config.loader.batch_size)
126
  train_ds = data_module.train_dataloader()
 
236
  ppl = eval_utils.compute_ppl(pretrained, valid_ds)
237
  print(f"PPL: {ppl:0.3f}")
238
 
239
+ def _test(config, logger, tokenizer):
240
+
241
+ test_dataset = load_from_disk('/home/tc415/discrete-diffusion-guidance/dataset/3000_400k/test')
242
+ data_module = dataloader.CustomDataModule(None, None, test_dataset=test_dataset, tokenizer=tokenizer, config=config, batch_size=config.loader.batch_size)
243
+ test_ds = data_module.test_dataloader()
244
+
245
+ model = diffusion.Diffusion.load_from_checkpoint(config.eval.checkpoint_path, tokenizer=tokenizer, config=config, logger=False)
246
+ model.eval()
247
+
248
+ # Create a test trainer (without training)
249
+ trainer = hydra.utils.instantiate(
250
+ config.trainer,
251
+ default_root_dir=os.getcwd(),
252
+ # logger=wandb_logger,
253
+ strategy=hydra.utils.instantiate(config.strategy),
254
+ callbacks=[] # No need for callbacks during testing
255
+ )
256
+
257
+ # Test the model
258
+ results = trainer.test(model, test_ds)
259
+
260
+ # Log or print test results
261
+ print(f"Test results: {results}")
262
+
263
+ return results
264
+
265
 
266
  @hydra.main(version_base=None, config_path='configs',
267
  config_name='config')
 
280
  elif 'train' in config.mode:
281
  _train(config, logger, tokenizer,
282
  train_classifier='classifier' in config.mode)
283
+ elif 'test' in config.mode:
284
+ _test(config, logger, tokenizer)
285
  else:
286
  raise NotImplementedError(f"Mode {config.mode} not implemented.")
287