import abc

import torch

from src.linearize import LinearizedImageEncoder
from src.modeling import ImageEncoder
from src.attention_only_finetune import AttentionOnlyFinetuneEncoder


class _TaskVector(abc.ABC):
    def __init__(
        self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None
    ):
        """Initializes the task vector from a pretrained and a finetuned checkpoint.

        This can either be done by passing two state dicts (one corresponding to the
        pretrained model, and another to the finetuned model), or by directly passing
        in the task vector state dict.
        """
        if vector is not None:
            self.vector = vector
        else:
            assert (
                pretrained_checkpoint is not None and finetuned_checkpoint is not None
            )
            with torch.no_grad():
                pretrained_obj = self._load_checkpoint(pretrained_checkpoint)
                finetuned_obj = self._load_checkpoint(finetuned_checkpoint)

                # Checkpoints may be full models or plain state dicts.
                if hasattr(pretrained_obj, "state_dict"):
                    pretrained_state_dict = pretrained_obj.state_dict()
                else:
                    pretrained_state_dict = pretrained_obj
                if hasattr(finetuned_obj, "state_dict"):
                    finetuned_state_dict = finetuned_obj.state_dict()
                else:
                    finetuned_state_dict = finetuned_obj

                # The task vector is the difference between finetuned and pretrained
                # weights, restricted to floating-point tensors.
                self.vector = {}
                for key in pretrained_state_dict:
                    if pretrained_state_dict[key].dtype not in [
                        torch.float32,
                        torch.float16,
                        torch.bfloat16,
                    ]:
                        continue
                    if key in finetuned_state_dict:
                        self.vector[key] = (
                            finetuned_state_dict[key] - pretrained_state_dict[key]
                        )

    @abc.abstractmethod
    def _load_checkpoint(self, checkpoint):
        raise NotImplementedError

    @abc.abstractmethod
    def _cast_to_same_type(self, other):
        raise NotImplementedError

    def __add__(self, other):
        """Add two task vectors together."""
        other = self._cast_to_same_type(other)
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                if key not in other.vector:
                    print(f"Warning: key {key} is not present in both task vectors.")
                    continue
                new_vector[key] = self.vector[key] + other.vector[key]
        return self.__class__(vector=new_vector)

    def __sub__(self, other):
        """Subtract one task vector from another."""
        return self.__add__(-other)

    def __radd__(self, other):
        """Support sum(), which starts the reduction from 0."""
        if other is None or isinstance(other, int):
            return self
        return self.__add__(other)

    def __neg__(self):
        """Negate a task vector."""
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                new_vector[key] = -self.vector[key]
        return self.__class__(vector=new_vector)

    def __pow__(self, power):
        """Raise a task vector to a power, element-wise."""
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                new_vector[key] = self.vector[key] ** power
        return self.__class__(vector=new_vector)

    def __mul__(self, other):
        """Scale a task vector by a scalar."""
        with torch.no_grad():
            new_vector = {}
            for key in self.vector:
                new_vector[key] = other * self.vector[key]
        return self.__class__(vector=new_vector)

    def dot(self, other):
        """Dot product of two task vectors."""
        other = self._cast_to_same_type(other)
        with torch.no_grad():
            dot_product = 0.0
            for key in self.vector:
                if key not in other.vector:
                    print(f"Warning: key {key} is not present in both task vectors.")
                    continue
                dot_product += torch.sum(self.vector[key] * other.vector[key])
        return dot_product

    def norm(self):
        """L2 norm of a task vector."""
        return torch.sqrt(self.dot(self))

    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
        """Apply a task vector to a pretrained model."""
        with torch.no_grad():
            pretrained_model = self._load_checkpoint(pretrained_checkpoint)
            if hasattr(pretrained_model, "state_dict"):
                new_state_dict = pretrained_model.state_dict()
            else:
                new_state_dict = pretrained_model.copy()
            pretrained_state_dict = new_state_dict.copy()
            for key in pretrained_state_dict:
                if key in self.vector:
                    new_state_dict[key] = (
                        pretrained_state_dict[key] + scaling_coef * self.vector[key]
                    )
            if hasattr(pretrained_model, "state_dict"):
                pretrained_model.load_state_dict(new_state_dict)
                return pretrained_model
            # The checkpoint was a plain state dict, so build a model to load the
            # updated weights into.
            from src.args import parse_arguments

            args = parse_arguments()
            if isinstance(self, NonLinearTaskVector):
                encoder = self._build_model_from_checkpoint(pretrained_checkpoint, args)
                encoder.load_state_dict(new_state_dict)
                return encoder
            pretrained_model.load_state_dict(new_state_dict)
            return pretrained_model


class NonLinearTaskVector(_TaskVector):
    """A task vector for nonlinear models."""

    def _load_checkpoint(self, checkpoint):
        return torch.load(checkpoint, map_location="cpu")

    def _build_model_from_checkpoint(self, checkpoint_path, args):
        mode = args.finetuning_mode
        if mode in ["linear-2", "linear-2_ortho"]:
            return AttentionOnlyFinetuneEncoder(args)
        return ImageEncoder(args)

    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
        """Apply a task vector to a pretrained nonlinear image encoder."""
        with torch.no_grad():
            from src.args import parse_arguments

            args = parse_arguments()
            pretrained_model = self._build_model_from_checkpoint(
                pretrained_checkpoint, args
            )
            pretrained_state_dict = torch.load(pretrained_checkpoint, map_location="cpu")
            if hasattr(pretrained_state_dict, "state_dict"):
                pretrained_state_dict = pretrained_state_dict.state_dict()
            new_state_dict = pretrained_state_dict.copy()
            for key in pretrained_state_dict:
                if key in self.vector:
                    new_state_dict[key] += scaling_coef * self.vector[key]
            pretrained_model.load_state_dict(new_state_dict)
        return pretrained_model

    def _cast_to_same_type(self, other):
        if isinstance(other, LinearizedTaskVector):
            return linear_to_nonlinear(other, self.vector.keys())
        return other


class LinearizedTaskVector(_TaskVector):
    """A task vector for linearized models."""

    def _load_checkpoint(self, checkpoint):
        return LinearizedImageEncoder.load(checkpoint)

    def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
        """Apply a task vector to a pretrained linearized image encoder."""
        with torch.no_grad():
            pretrained_model = self._load_checkpoint(pretrained_checkpoint)
            new_state_dict = pretrained_model.state_dict()
            pretrained_state_dict = new_state_dict.copy()
            for key in pretrained_state_dict:
                if key in self.vector:
                    new_state_dict[key] += scaling_coef * self.vector[key]
            pretrained_model.load_state_dict(new_state_dict)
        return pretrained_model

    def get_named_parameters(self, param_names):
        """Map the task-vector parameters (excluding params0) onto `param_names`."""
        params = {k: v for k, v in self.vector.items() if "model.params0" not in k}
        return {k: v for k, v in zip(param_names, params.values())}

    def _cast_to_same_type(self, other):
        if isinstance(other, NonLinearTaskVector):
            return nonlinear_to_linear(other)
        return other


def nonlinear_to_linear(nonlinear_task_vector):
    """Convert a nonlinear task vector to the linearized parameter layout."""
    if isinstance(nonlinear_task_vector, LinearizedTaskVector):
        return nonlinear_task_vector
    linear_params = {
        f"model.params.{i}": v
        for i, v in enumerate(nonlinear_task_vector.vector.values())
    }
    # The entries for the initial parameters (params0) of a task vector are all zero.
    linear_params.update(
        {
            f"model.params0.{i}": torch.zeros_like(v)
            for i, v in enumerate(nonlinear_task_vector.vector.values())
        }
    )
    return LinearizedTaskVector(vector=linear_params)


def linear_to_nonlinear(linear_task_vector, param_names):
    """Convert a linearized task vector to the nonlinear parameter layout."""
    if isinstance(linear_task_vector, NonLinearTaskVector):
        return linear_task_vector
    return NonLinearTaskVector(
        vector=linear_task_vector.get_named_parameters(param_names)
    )
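

# Usage sketch: task vectors compose through ordinary arithmetic before being applied
# to a pretrained checkpoint. The checkpoint paths below are hypothetical and only
# illustrate the intended call pattern.
#
#   tv_a = NonLinearTaskVector("checkpoints/zeroshot.pt", "checkpoints/finetuned_a.pt")
#   tv_b = NonLinearTaskVector("checkpoints/zeroshot.pt", "checkpoints/finetuned_b.pt")
#   merged = sum([tv_a, tv_b])  # __radd__ absorbs sum()'s initial 0
#   edited_encoder = merged.apply_to("checkpoints/zeroshot.pt", scaling_coef=0.5)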