| import os |
| import pickle |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def assign_learning_rate(param_group, new_lr):
    """Overwrite the "lr" entry of an optimizer param group in place."""
    param_group["lr"] = new_lr
|
|
|
|
| def _warmup_lr(base_lr, warmup_length, step): |
| return base_lr * (step + 1) / warmup_length |
|
|
|
|
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a step-wise LR schedule: linear warmup followed by cosine decay.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        base_lrs: peak learning rate — a single value (broadcast to every
            param group) or a list with one entry per param group.
        warmup_length: number of linear-warmup steps.
        steps: total number of scheduler steps (warmup + decay).

    Returns:
        ``_lr_adjuster(step)``: assigns the LR for ``step`` to every param
        group and returns the last LR assigned (``None`` if the optimizer has
        no param groups), so callers can log the current learning rate.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        lr = None
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                # Cosine decay from base_lr down to 0 over the post-warmup steps.
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)
        # Returning the LR lets training loops log it; previously this
        # function returned None, which forced callers to re-read param_groups.
        return lr

    return _lr_adjuster
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Return the number of top-k-correct predictions for each k in topk.

    Args:
        output: (batch, num_classes) score/logit tensor.
        target: (batch,) tensor of ground-truth class indices.
        topk: tuple of k values to evaluate.

    Returns:
        List of floats — for each k, the count of samples whose target
        appears among the k highest-scoring classes.
    """
    maxk = max(topk)
    # Indices of the maxk largest scores per row, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    counts = []
    for k in topk:
        num_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
        counts.append(float(num_correct.cpu().numpy()))
    return counts
|
|
|
|
def torch_load_old(save_path, device=None):
    """Load a classifier from a legacy pickle file.

    NOTE(review): pickle.load executes arbitrary code during deserialization —
    only use this on trusted files.

    Args:
        save_path: path to the pickled classifier.
        device: optional device; when given, the classifier is moved there.

    Returns:
        The unpickled classifier (on ``device`` if one was provided).
    """
    with open(save_path, "rb") as handle:
        classifier = pickle.load(handle)
    if device is None:
        return classifier
    return classifier.to(device)
|
|
|
|
def torch_save(model, save_path):
    """Serialize ``model`` to ``save_path``, creating parent dirs as needed."""
    parent = os.path.dirname(save_path)
    # dirname is "" for bare filenames; makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model, save_path)
|
|
|
|
def torch_load(save_path, device=None):
    """Deserialize a model saved with torch.save.

    The checkpoint is always materialized on CPU first, then optionally
    moved to ``device``.
    """
    loaded = torch.load(save_path, map_location="cpu")
    return loaded if device is None else loaded.to(device)
|
|
|
|
def get_logits(inputs, classifier):
    """Apply ``classifier`` to ``inputs`` and return the raw logits.

    Module-style classifiers (anything exposing ``.to``) are first moved to
    the inputs' device; plain callables are invoked as-is.
    """
    assert callable(classifier)
    mover = getattr(classifier, "to", None)
    if mover is not None:
        classifier = mover(inputs.device)
    return classifier(inputs)
|
|
|
|
def get_probs(inputs, classifier):
    """Return class probabilities from either an sklearn-style or torch classifier.

    sklearn-style estimators (exposing ``predict_proba``) are fed a NumPy
    copy of the inputs; otherwise logits are computed and softmaxed.
    """
    predict_proba = getattr(classifier, "predict_proba", None)
    if predict_proba is not None:
        return torch.from_numpy(predict_proba(inputs.detach().cpu().numpy()))
    return get_logits(inputs, classifier).softmax(dim=1)
|
|
|
|
class LabelSmoothing(torch.nn.Module):
    """Cross-entropy loss with label smoothing.

    With smoothing ``s`` the per-example loss is
    ``(1 - s) * NLL(true class) + s * mean_over_classes(-log p)``,
    which matches ``F.cross_entropy(..., label_smoothing=s)``.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        # Weight on the true-label term; the remainder is spread uniformly.
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        """Compute the smoothed loss for logits ``x`` and class indices ``target``."""
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform = -logprobs.mean(dim=-1)
        per_example = self.confidence * nll + self.smoothing * uniform
        return per_example.mean()
|
|
|
|
class DotDict(dict):
    """dot.notation access to dictionary attributes.

    Attribute reads fall back to ``dict.get``, so a missing key yields
    ``None`` rather than raising AttributeError; writes and deletes map
    straight onto the underlying dict entries.
    """

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; mirrors dict.get,
        # returning None for absent keys.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
|
|
|
|
def find_optimal_coef(
    results,
    metric="avg_normalized_top1",
    minimize=False,
    control_metric=None,
    control_metric_threshold=0.0,
):
    """Pick the scaling coefficient whose ``metric`` value is best.

    Args:
        results: mapping from scaling coefficient to a dict of metric values.
        metric: key of the metric to optimize.
        minimize: if True select the smallest metric value, else the largest.
        control_metric: optional key of a gating metric; coefficients whose
            control metric is below ``control_metric_threshold`` are skipped.
        control_metric_threshold: minimum acceptable control-metric value.

    Returns:
        The best coefficient, or None if ``results`` is empty or every
        candidate was filtered out by the control metric.
    """
    best_coef = None
    # +/-inf sentinels work for any metric range; the previous 0/1 sentinels
    # silently returned None whenever metric values fell outside [0, 1].
    best_metric = float("inf") if minimize else float("-inf")
    for scaling_coef, entry in results.items():
        if control_metric is not None and entry[control_metric] < control_metric_threshold:
            print(f"Control metric fell below {control_metric_threshold} threshold")
            continue
        value = entry[metric]
        is_better = value < best_metric if minimize else value > best_metric
        if is_better:
            best_metric = value
            best_coef = scaling_coef
    return best_coef
|
|
|
|
def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
    """Accuracy gap normalized by the headroom above chance (1 - 1/num_classes)."""
    chance_headroom = 1.0 - 1.0 / num_classes
    return (nonlinear_acc - linear_acc) / chance_headroom
|
|
|
|
def calculate_linearized_orthogonality_loss(linearized_model):
    """Compute orthogonality loss ||delta_W^T delta_W - I||_F for a LinearizedModel.

    Iterates the model's (finetuned, pretrained) parameter pairs and sums the
    Frobenius distance between each weight delta's Gram matrix and identity.
    """
    total = 0.0
    for tuned, frozen in zip(linearized_model.params, linearized_model.params0):
        # Only trainable 2-D weight matrices contribute (skip biases/norms).
        if not (tuned.requires_grad and tuned.dim() == 2):
            continue
        delta = tuned - frozen
        n_rows, n_cols = delta.shape
        # Use the smaller of the two Gram matrices for non-square deltas.
        if n_rows < n_cols:
            gram = delta @ delta.T
            eye = torch.eye(n_rows, device=delta.device)
        else:
            gram = delta.T @ delta
            eye = torch.eye(n_cols, device=delta.device)
        total = total + torch.norm(gram - eye, p='fro')
    return total
|
|
|
|
def calculate_standard_orthogonality_loss(model, pretrained_state_dict):
    """Compute orthogonality loss ||delta_W^T delta_W - I||_F for standard/linear-2 finetuning.

    Args:
        model: DDP-wrapped ImageClassifier (ddp_model).
        pretrained_state_dict: snapshot of the pretrained model's inner ViT state_dict.

    Returns:
        Sum over trainable 2-D weights of the Frobenius distance between the
        weight delta's Gram matrix and identity.
    """
    total = 0.0
    for name, tuned in model.module.image_encoder.model.named_parameters():
        # Skip frozen params and non-matrix params (biases, norms).
        if not (tuned.requires_grad and tuned.dim() == 2):
            continue
        # Skip params with no pretrained snapshot to diff against.
        if name not in pretrained_state_dict:
            continue
        frozen = pretrained_state_dict[name].to(tuned.device)
        delta = tuned - frozen
        n_rows, n_cols = delta.shape
        # Use the smaller of the two Gram matrices for non-square deltas.
        if n_rows < n_cols:
            gram = delta @ delta.T
            eye = torch.eye(n_rows, device=delta.device)
        else:
            gram = delta.T @ delta
            eye = torch.eye(n_cols, device=delta.device)
        total = total + torch.norm(gram - eye, p='fro')
    return total
|
|