| import torch |
| import os |
| import torch.nn.functional as F |
|
|
| def save_checkpoint(model, filelocation, save_parallel = True): |
| if save_parallel: |
| torch.save(model.module.state_dict(), filelocation) |
| else: |
| torch.save(model.state_dict(), filelocation) |
|
|
| def load_Checkpoint(fileLocation,model, load_cpu=False): |
| if load_cpu: |
| model.load_state_dict(torch.load(fileLocation,map_location=lambda storage, loc: storage)) |
| else: |
| model.load_state_dict(torch.load(fileLocation)) |
| return model |
|
|
| def writeLog(logList, filename): |
| with open(filename, 'w') as outfile: |
| outfile.write("\n".join(logList)) |
|
|
|
|
| def kl_loss(mu, logvar): |
| return -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean() |
|
|
|
|
|
|