| import glob |
| import os |
| from collections import OrderedDict |
|
|
| import torch |
|
|
|
|
class Saver(object):
    """Manage a per-run output directory and write artifacts into it.

    Each instance claims a fresh directory of the form
    ``run/<train_dataset>/<checkname>/experiment_<id>`` (relative to the
    current working directory), where ``<id>`` is one greater than the
    highest numeric id already present, or 0 when there is none.

    Attributes:
        args: The argument object passed to the constructor.
        directory: Parent directory that holds all ``experiment_*`` runs.
        runs: Paths of the runs that existed at construction time, in
            lexicographic order (so ``runs[-1]`` is not necessarily the
            run with the highest id).
        experiment_dir: Directory created for this run.
    """

    def __init__(self, args):
        """Create the directory for a new experiment run.

        Args:
            args: Object exposing ``train_dataset`` and ``checkname``
                (used to build the path); ``save_experiment_config``
                additionally reads ``lr`` and ``epochs``.
        """
        self.args = args
        self.directory = os.path.join("run", args.train_dataset, args.checkname)
        self.runs = sorted(glob.glob(os.path.join(self.directory, "experiment_*")))

        # Pick the next id from the numeric maximum. Relying on the last
        # entry of the lexicographically sorted list breaks once ids reach
        # two digits ("experiment_10" sorts before "experiment_9"), which
        # would make a new run reuse an existing directory.
        used_ids = [
            idx
            for idx in (self._run_index(path) for path in self.runs)
            if idx is not None
        ]
        run_id = max(used_ids) + 1 if used_ids else 0

        self.experiment_dir = os.path.join(
            self.directory, "experiment_{}".format(str(run_id))
        )
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.experiment_dir, exist_ok=True)

    @staticmethod
    def _run_index(path):
        """Return the integer suffix of an ``experiment_<n>`` path, or None."""
        suffix = os.path.basename(path).rsplit("_", 1)[-1]
        try:
            return int(suffix)
        except ValueError:
            # Entries such as "experiment_backup" do not take part in the
            # id numbering.
            return None

    def save_checkpoint(self, state, filename="checkpoint.pth.tar"):
        """Save a checkpoint to disk inside the experiment directory.

        Args:
            state: Object to serialize with ``torch.save``.
            filename: File name, joined onto ``self.experiment_dir``.
        """
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def save_experiment_config(self):
        """Write selected run parameters to ``parameters.txt``.

        The file is created in ``self.experiment_dir`` and contains one
        ``key:value`` line each for ``train_dataset``, ``lr`` and ``epoch``
        (the latter taken from ``args.epochs``), in that order.
        """
        logfile = os.path.join(self.experiment_dir, "parameters.txt")
        params = (
            ("train_dataset", self.args.train_dataset),
            ("lr", self.args.lr),
            ("epoch", self.args.epochs),
        )
        # The context manager closes the file even if a write fails.
        with open(logfile, "w") as log_file:
            for key, val in params:
                log_file.write(key + ":" + str(val) + "\n")
|
|