# OrthoReg / src/utils.py
# (uploaded via huggingface_hub by gezi2333, revision 3589275)
import os
import pickle
import numpy as np
import torch
def assign_learning_rate(param_group, new_lr):
    """Write new_lr into a single optimizer param-group dict."""
    param_group.update(lr=new_lr)
def _warmup_lr(base_lr, warmup_length, step):
    """Linear warmup: ramp from base_lr/warmup_length up to base_lr."""
    fraction = (step + 1) / warmup_length
    return base_lr * fraction
def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    """Build a scheduler callable: linear warmup followed by cosine decay.

    Returns a function of the global step that writes the new learning rate
    into every param group of ``optimizer``. ``base_lrs`` may be a single
    value (shared by all groups) or a list with one entry per param group.
    """
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs] * len(optimizer.param_groups)
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                # Linear ramp: fraction (step + 1) / warmup_length of base_lr.
                new_lr = base_lr * (step + 1) / warmup_length
            else:
                # Cosine decay over the remaining steps after warmup.
                progress = (step - warmup_length) / (steps - warmup_length)
                new_lr = 0.5 * (1 + np.cos(np.pi * progress)) * base_lr
            group["lr"] = new_lr

    return _lr_adjuster
def accuracy(output, target, topk=(1,)):
    """Count correct predictions within the top-k logits for each k in topk.

    Returns a list of floats (raw correct counts, not percentages), one per k.
    """
    maxk = max(topk)
    pred = output.topk(maxk, 1, True, True)[1]
    pred = pred.t()
    # Broadcast targets across the k dimension and mark the hits.
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    counts = []
    for k in topk:
        top_k_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
        counts.append(float(top_k_correct.cpu().numpy()))
    return counts
def torch_load_old(save_path, device=None):
    """Load a classifier serialized with raw pickle (legacy checkpoint format).

    NOTE(review): pickle.load executes arbitrary code during
    deserialization — only use on trusted checkpoint files.

    Args:
        save_path: path to the pickled classifier.
        device: optional torch device; if given, the loaded object is moved
            there via .to(device).

    Returns:
        The unpickled classifier object.
    """
    with open(save_path, "rb") as f:
        classifier = pickle.load(f)
    if device is not None:
        classifier = classifier.to(device)
    return classifier
def torch_save(model, save_path):
    """Serialize model to save_path, creating parent directories as needed."""
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model, save_path)
def torch_load(save_path, device=None):
    """Load a checkpoint saved with torch.save, optionally moving it to device.

    NOTE(review): torch.load unpickles objects — only load trusted
    checkpoint files.
    """
    model = torch.load(save_path, map_location="cpu")
    if device is None:
        return model
    return model.to(device)
def get_logits(inputs, classifier):
    """Apply classifier to inputs, first moving it onto inputs' device if movable."""
    assert callable(classifier)
    movable = hasattr(classifier, "to")
    if movable:
        classifier = classifier.to(inputs.device)
    return classifier(inputs)
def get_probs(inputs, classifier):
    """Return class probabilities for inputs.

    Uses predict_proba for sklearn-style classifiers; otherwise softmaxes
    the logits produced by get_logits.
    """
    if hasattr(classifier, "predict_proba"):
        np_probs = classifier.predict_proba(inputs.detach().cpu().numpy())
        return torch.from_numpy(np_probs)
    return get_logits(inputs, classifier).softmax(dim=1)
class LabelSmoothing(torch.nn.Module):
    """Cross-entropy loss with label smoothing.

    With smoothing=0 this reduces to standard cross-entropy; otherwise the
    gold-label NLL is blended with a uniform term over all classes.
    """

    def __init__(self, smoothing=0.0):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        # NLL term: negative log-probability of the gold class.
        per_sample_nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Smoothing term: uniform average of -log p over all classes.
        uniform_term = -log_probs.mean(dim=-1)
        combined = self.confidence * per_sample_nll + self.smoothing * uniform_term
        return combined.mean()
class DotDict(dict):
    """Dictionary whose items can also be read/written as attributes.

    Missing attribute reads return None (dict.get semantics) rather than
    raising AttributeError; deleting a missing attribute raises KeyError.
    """

    def __getattr__(self, name):
        # dict.get semantics: unknown names yield None, not AttributeError.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
def find_optimal_coef(
    results,
    metric="avg_normalized_top1",
    minimize=False,
    control_metric=None,
    control_metric_threshold=0.0,
):
    """Pick the scaling coefficient whose results entry optimizes metric.

    Args:
        results: mapping from scaling coefficient -> dict of metric values.
        metric: key of the metric to optimize.
        minimize: if True pick the smallest metric value, else the largest.
        control_metric: optional key that must stay >= control_metric_threshold
            for a coefficient to be eligible.
        control_metric_threshold: eligibility floor for control_metric.

    Returns:
        The best eligible scaling coefficient, or None if none is eligible.
    """
    best_coef = None
    # Bug fix: best_metric used to start at 1 (minimize) / 0 (maximize),
    # which silently returned None whenever every metric value fell outside
    # [0, 1] (e.g. losses > 1, negative advantages). +/-inf works for any range.
    best_metric = float("inf") if minimize else float("-inf")
    for scaling_coef, metrics in results.items():
        if control_metric is not None and metrics[control_metric] < control_metric_threshold:
            print(f"Control metric fell below {control_metric_threshold} threshold")
            continue
        value = metrics[metric]
        is_better = value < best_metric if minimize else value > best_metric
        if is_better:
            best_metric = value
            best_coef = scaling_coef
    return best_coef
def nonlinear_advantage(nonlinear_acc, linear_acc, num_classes):
    """Normalize the accuracy gap by the headroom above chance (1 - 1/C)."""
    headroom = 1.0 - 1.0 / num_classes
    return (nonlinear_acc - linear_acc) / headroom
def calculate_linearized_orthogonality_loss(linearized_model):
    """Compute orthogonality loss ||delta_W^T delta_W - I||_F for a LinearizedModel."""
    total = 0.0
    for p_finetuned, p_pretrained in zip(linearized_model.params, linearized_model.params0):
        # Only trainable 2-D weight matrices contribute; biases/norms are skipped.
        if not (p_finetuned.requires_grad and p_finetuned.dim() == 2):
            continue
        delta = p_finetuned - p_pretrained
        n_rows, n_cols = delta.shape
        # Use the smaller Gram matrix so the identity target is full rank.
        if n_rows < n_cols:
            gram = delta @ delta.T
            eye = torch.eye(n_rows, device=delta.device)
        else:
            gram = delta.T @ delta
            eye = torch.eye(n_cols, device=delta.device)
        total = total + torch.norm(gram - eye, p='fro')
    return total
def calculate_standard_orthogonality_loss(model, pretrained_state_dict):
    """Compute orthogonality loss ||delta_W^T delta_W - I||_F for standard/linear-2 finetuning.

    Args:
        model: DDP-wrapped ImageClassifier (ddp_model).
        pretrained_state_dict: snapshot of the pretrained model's inner ViT state_dict.
    """
    total = 0.0
    for name, param in model.module.image_encoder.model.named_parameters():
        # Only trainable 2-D weight matrices that have a pretrained counterpart.
        if not (param.requires_grad and param.dim() == 2):
            continue
        if name not in pretrained_state_dict:
            continue
        reference = pretrained_state_dict[name].to(param.device)
        delta = param - reference
        n_rows, n_cols = delta.shape
        # Use the smaller Gram matrix so the identity target is full rank.
        if n_rows < n_cols:
            gram = delta @ delta.T
            eye = torch.eye(n_rows, device=delta.device)
        else:
            gram = delta.T @ delta
            eye = torch.eye(n_cols, device=delta.device)
        total = total + torch.norm(gram - eye, p='fro')
    return total