import numpy as np
import torch
import torch.nn as nn


class BasicClassificationLoss(nn.Module):
    """Thin wrapper around cross-entropy loss for classification."""

    def __init__(self):
        super().__init__()
        self.classification_loss = nn.CrossEntropyLoss()

    def forward(self, pred_labels, gt_labels):
        # pred_labels: raw logits of shape (batch, num_classes);
        # gt_labels: integer class indices of shape (batch,).
        return self.classification_loss(pred_labels, gt_labels)
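

# A minimal usage sketch, not part of the original training code: the batch
# size, class count, and random tensors below are illustrative assumptions.
def example_classification_loss():
    criterion = BasicClassificationLoss()
    pred_logits = torch.randn(8, 10)         # batch of 8 samples, 10 classes
    gt_labels = torch.randint(0, 10, (8,))   # integer class targets
    return criterion(pred_logits, gt_labels)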


def save_model(trained_model, optimiser_used):
    # Save both the fully pickled model and just its weights; the weights
    # file is the one load_model_for_eval expects.
    torch.save(trained_model, 'trainedClassifier.pth')
    torch.save(trained_model.state_dict(), 'trainedClassifier_weights.pth')
    # Pickle the whole optimiser too, so training can be resumed later.
    torch.save(optimiser_used, 'optimiserUsed.pth')
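

# Hedged sketch of calling save_model after a (hypothetical) training run;
# nn.Linear stands in for the real classifier, and Adam with lr=1e-3 is an
# assumption, not a detail taken from this file.
def example_save():
    model = nn.Linear(416, 10)
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    save_model(model, optimiser)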


def load_model_for_eval(file_path, model_type):
    # Instantiate the architecture (416 is the expected input size), then
    # load the saved weights into it.
    model_template = model_type(416)
    model_template.load_state_dict(torch.load(file_path, weights_only=True))
    model_template.eval()  # disable dropout/batch-norm updates for inference
    return model_template
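

# The real model class is not defined in this file; TinyClassifier is a
# hypothetical stand-in whose constructor takes the input size, matching the
# model_type(416) call above. The weights file is assumed to exist on disk
# and to match this architecture.
class TinyClassifier(nn.Module):
    def __init__(self, input_size, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        return self.fc(x)


def example_load_and_predict():
    model = load_model_for_eval('trainedClassifier_weights.pth', TinyClassifier)
    with torch.no_grad():  # no gradients needed at inference time
        logits = model(torch.randn(1, 416))
    return logits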


def softmax(unprocessed_logits):
    logits = np.array(unprocessed_logits)
    # Subtract the max logit before exponentiating: the result is unchanged
    # mathematically, but this avoids overflow for large logits.
    exponentials = np.exp(logits - np.max(logits))
    return exponentials / np.sum(exponentials)
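

# Quick illustrative check that the NumPy softmax matches torch.softmax on
# the same logits; the values here are arbitrary.
def example_softmax():
    raw_logits = [2.0, 1.0, 0.1]
    np_probs = softmax(raw_logits)
    torch_probs = torch.softmax(torch.tensor(raw_logits), dim=0).numpy()
    assert np.allclose(np_probs, torch_probs)
    return np_probs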