# OrthoReg / src/modeling.py
# Uploaded by gezi2333 ("Upload folder using huggingface_hub"),
# revision 3589275 (verified).
import open_clip
import torch
from src import utils
class ImageEncoder(torch.nn.Module):
    """Wrap an open_clip model and expose its ``encode_image`` as ``forward``.

    The wrapper also keeps the train/validation preprocessing transforms
    returned by open_clip and can optionally return an orthogonality
    penalty computed on the difference between the current weights and a
    reference (pre-trained) state dict.
    """

    def __init__(self, args, keep_lang=False):
        """Create the open_clip model described by ``args.model``.

        ``args.model`` is interpreted as follows:

        * ``"<name>__pretrained__<tag>"``: model ``<name>`` with pretrained
          tag ``<tag>``.
        * ``"<name>__init__..."``: model ``<name>`` with ``pretrained=None``.
        * anything else: the string is the model name and the pretrained
          tag is ``"openai"``.

        Args:
            args: Object providing the attributes ``model``,
                ``openclip_cachedir`` and ``cache_dir``.
            keep_lang: When false, the model's ``transformer`` attribute is
                deleted if it exists.
        """
        super().__init__()

        print(f"Loading {args.model} pre-trained weights.")
        if "__pretrained__" in args.model:
            name, pretrained = args.model.split("__pretrained__")
        elif "__init__" in args.model:
            print("Using random initialization.")
            name, pretrained = args.model.split("__init__")[0], None
        else:
            name = args.model
            pretrained = "openai"

        (
            self.model,
            self.train_preprocess,
            self.val_preprocess,
        ) = open_clip.create_model_and_transforms(
            name, pretrained=pretrained, cache_dir=args.openclip_cachedir
        )

        self.cache_dir = args.cache_dir

        if not keep_lang and hasattr(self.model, "transformer"):
            delattr(self.model, "transformer")

    def forward(self, images, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Encode ``images`` and optionally compute the orthogonal loss.

        Args:
            images: Input passed unchanged to ``self.model.encode_image``.
            calculate_ortho_loss: When true, also return the penalty
                described below.
            pretrained_state_dict: Mapping from parameter name (as yielded
                by ``self.model.named_parameters()``) to the reference
                tensor. Required when ``calculate_ortho_loss`` is true.

        Returns:
            The encoded features, or ``(features, ortho_loss)`` when
            ``calculate_ortho_loss`` is true. ``ortho_loss`` stays the float
            ``0.0`` if no parameter qualifies.

        Raises:
            ValueError: If ``calculate_ortho_loss`` is true and
                ``pretrained_state_dict`` is ``None``.
        """
        features = self.model.encode_image(images)

        if not calculate_ortho_loss:
            return features

        if pretrained_state_dict is None:
            raise ValueError("pretrained_state_dict must be provided when calculate_ortho_loss is True")

        ortho_loss = 0.0
        for name, p_finetuned in self.model.named_parameters():
            # Only trainable 2-D parameters that have a reference entry
            # contribute to the penalty.
            if not (p_finetuned.requires_grad and p_finetuned.dim() == 2):
                continue
            if name not in pretrained_state_dict:
                continue

            p_pretrained = pretrained_state_dict[name].to(p_finetuned.device)
            delta_W = p_finetuned - p_pretrained

            # Form the Gram matrix on the smaller side of delta_W so the
            # product is (min(rows, cols) x min(rows, cols)).
            rows, cols = delta_W.shape
            if rows < cols:
                mat = delta_W @ delta_W.T
                identity = torch.eye(rows, device=delta_W.device)
            else:
                mat = delta_W.T @ delta_W
                identity = torch.eye(cols, device=delta_W.device)

            ortho_loss += torch.norm(mat - identity, p='fro')

        return features, ortho_loss

    def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Call ``forward`` directly with all three arguments.

        NOTE: this overrides ``torch.nn.Module.__call__``, so hooks
        registered on this module are not executed.
        """
        return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)

    def save(self, filename):
        """Pass this encoder and ``filename`` to ``utils.torch_save``."""
        print(f"Saving image encoder to {filename}")
        utils.torch_save(self, filename)

    @classmethod
    def load(
        cls,
        model_name,
        filename,
        openclip_cachedir=None,
        cache_dir=None,
        keep_lang=False,
    ):
        """Read a state dict from ``filename`` and build an encoder from it.

        Args:
            model_name: Model specification, same format as ``args.model``
                in ``__init__``.
            filename: Path given to ``torch.load`` (mapped to CPU).
            openclip_cachedir: Forwarded to ``load_from_state_dict``.
            cache_dir: Forwarded to ``load_from_state_dict``.
            keep_lang: Forwarded to ``load_from_state_dict``.

        Returns:
            The encoder returned by ``load_from_state_dict``.
        """
        print(f"Loading image encoder from {filename}")
        state_dict = torch.load(filename, map_location="cpu")
        # Delegate to the state-dict constructor; calling ``cls.load`` here
        # would re-enter this method with a dict in place of a filename.
        return cls.load_from_state_dict(
            model_name,
            state_dict,
            openclip_cachedir=openclip_cachedir,
            cache_dir=cache_dir,
            keep_lang=keep_lang,
        )

    @classmethod
    def load_from_state_dict(
        cls,
        model_name,
        state_dict,
        openclip_cachedir=None,
        cache_dir=None,
        keep_lang=False,
    ):
        """Build an encoder for ``model_name`` and load ``state_dict`` into it.

        The encoder is created through ``cls(...)`` so the model-name
        parsing and language-tower handling match ``__init__``.

        Args:
            model_name: Model specification, same format as ``args.model``
                in ``__init__``.
            state_dict: Weights to load. If every key starts with
                ``"model."`` the dict is loaded into the encoder itself,
                otherwise into ``encoder.model``.
                NOTE(review): which key layout ``utils.torch_save`` produces
                is not visible here -- confirm against that helper.
            openclip_cachedir: Used as ``args.openclip_cachedir``.
            cache_dir: Used as ``args.cache_dir``.
            keep_lang: Passed to ``__init__``.

        Returns:
            The newly built encoder with ``state_dict`` loaded.
        """
        from types import SimpleNamespace

        args = SimpleNamespace(
            model=model_name,
            openclip_cachedir=openclip_cachedir,
            cache_dir=cache_dir,
        )
        encoder = cls(args, keep_lang=keep_lang)

        wrapper_keyed = bool(state_dict) and all(
            key.startswith("model.") for key in state_dict
        )
        target = encoder if wrapper_keyed else encoder.model
        target.load_state_dict(state_dict)
        return encoder
class ClassificationHead(torch.nn.Linear):
    """Linear layer initialised from given weights, with optional input normalisation."""

    def __init__(self, normalize, weights, biases=None):
        """Create the head from a ``(num_outputs, num_inputs)`` weight tensor.

        Args:
            normalize: When truthy, ``forward`` divides its input by its
                norm along the last dimension before the linear map.
            weights: 2-D tensor; a clone becomes the layer weight.
            biases: Optional tensor; a clone becomes the layer bias. When
                ``None`` the bias is initialised to zeros.
        """
        num_outputs, num_inputs = weights.shape
        super().__init__(in_features=num_inputs, out_features=num_outputs)
        self.normalize = normalize

        self.weight = torch.nn.Parameter(weights.clone())
        initial_bias = torch.zeros_like(self.bias) if biases is None else biases.clone()
        self.bias = torch.nn.Parameter(initial_bias)

    def forward(self, inputs):
        """Apply the linear map, normalising ``inputs`` first if configured."""
        if not self.normalize:
            return super().forward(inputs)
        scaled = inputs / inputs.norm(dim=-1, keepdim=True)
        return super().forward(scaled)

    def __call__(self, inputs):
        """Call ``forward`` directly."""
        return self.forward(inputs)

    def save(self, filename):
        """Pass this head and ``filename`` to ``utils.torch_save``."""
        print(f"Saving classification head to {filename}")
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Return whatever ``utils.torch_load`` yields for ``filename``."""
        print(f"Loading classification head from {filename}")
        return utils.torch_load(filename)
class ImageClassifier(torch.nn.Module):
    """Chain an image encoder with a single classification head."""

    def __init__(self, image_encoder, classification_head):
        """Store both components and mirror the encoder's preprocessing.

        Args:
            image_encoder: Callable encoder, or ``None``. When not ``None``
                its ``train_preprocess`` and ``val_preprocess`` attributes
                are copied onto this object.
            classification_head: Module applied to the encoder features.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.classification_head = classification_head
        if image_encoder is None:
            return
        self.train_preprocess = image_encoder.train_preprocess
        self.val_preprocess = image_encoder.val_preprocess

    def freeze_head(self):
        """Disable gradients for the head's weight and bias."""
        head = self.classification_head
        for tensor in (head.weight, head.bias):
            tensor.requires_grad_(False)

    def forward(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Encode ``inputs`` and classify the resulting features.

        All three arguments are forwarded positionally to the encoder.

        Returns:
            The head outputs, or ``(outputs, ortho_loss)`` when
            ``calculate_ortho_loss`` is truthy; in that case the encoder is
            expected to return a ``(features, ortho_loss)`` pair.
        """
        encoded = self.image_encoder(inputs, calculate_ortho_loss, pretrained_state_dict)
        if not calculate_ortho_loss:
            return self.classification_head(encoded)
        features, ortho_loss = encoded
        return self.classification_head(features), ortho_loss

    def __call__(self, inputs, calculate_ortho_loss=False, pretrained_state_dict=None):
        """Call ``forward`` directly with all three arguments."""
        return self.forward(inputs, calculate_ortho_loss, pretrained_state_dict)

    def save(self, filename):
        """Pass this classifier and ``filename`` to ``utils.torch_save``."""
        print(f"Saving image classifier to {filename}")
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Return whatever ``utils.torch_load`` yields for ``filename``."""
        print(f"Loading image classifier from {filename}")
        return utils.torch_load(filename)
class MultiHeadImageClassifier(torch.nn.Module):
    """Share one image encoder across several classification heads."""

    def __init__(self, image_encoder, classification_heads):
        """Store the encoder and register the heads in a ``ModuleList``.

        Args:
            image_encoder: Callable encoder, or ``None``. When not ``None``
                its ``train_preprocess`` and ``val_preprocess`` attributes
                are copied onto this object.
            classification_heads: Iterable of head modules.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.classification_heads = torch.nn.ModuleList(classification_heads)
        if image_encoder is None:
            return
        self.train_preprocess = image_encoder.train_preprocess
        self.val_preprocess = image_encoder.val_preprocess

    def freeze_head(self):
        """Disable gradients for the weight and bias of every head."""
        for head in self.classification_heads:
            head.weight.requires_grad_(False)
            head.bias.requires_grad_(False)

    def forward(self, inputs, head_idx):
        """Encode ``inputs`` and apply the head at index ``head_idx``."""
        selected_head = self.classification_heads[head_idx]
        return selected_head(self.image_encoder(inputs))

    def __call__(self, inputs, head_idx):
        """Call ``forward`` directly."""
        return self.forward(inputs, head_idx)

    def save(self, filename):
        """Pass this classifier and ``filename`` to ``utils.torch_save``."""
        print(f"Saving image classifier to {filename}")
        utils.torch_save(self, filename)

    @classmethod
    def load(cls, filename):
        """Return whatever ``utils.torch_load`` yields for ``filename``."""
        print(f"Loading image classifier from {filename}")
        return utils.torch_load(filename)