"""Package a backbone model, a pooler and a probe as one ``PreTrainedModel``."""

import os
import sys
from typing import Any, Dict, Optional

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput

try:
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config
except ImportError:
    # `protify` is not importable as-is: put directories around this file on
    # sys.path (this dir, its two parents, and ./src) and retry the imports.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    candidate_paths = [
        current_dir,
        os.path.dirname(current_dir),
        os.path.dirname(os.path.dirname(current_dir)),
        os.path.join(current_dir, "src"),
    ]
    for candidate in candidate_paths:
        if os.path.isdir(candidate) and candidate not in sys.path:
            sys.path.insert(0, candidate)
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config


def _require(condition: Any, message: str) -> None:
    """Raise ``AssertionError(message)`` when *condition* is falsy.

    Used instead of the ``assert`` statement so that the check still runs when
    Python is started with ``-O``; the exception type and message are the same
    as an ``assert`` would produce.
    """
    if not condition:
        raise AssertionError(message)


class PackagedProbeConfig(PretrainedConfig):
    """Configuration for :class:`PackagedProbeModel`.

    Args:
        base_model_name: Key into ``all_presets_with_paths`` or, if not a key
            there, a value passed directly to ``AutoModel.from_pretrained``.
        probe_type: Probe family. ``PackagedProbeModel.forward`` handles
            ``"linear"``, ``"transformer"``, ``"retrievalnet"`` and ``"lyra"``.
        probe_config: Saved probe configuration handed to
            ``rebuild_probe_from_saved_config``. ``None`` becomes ``{}``.
        tokenwise: If true, the probe receives unpooled hidden states plus the
            attention mask. Also forwarded to ``rebuild_probe_from_saved_config``.
        matrix_embed: If true, the probe receives unpooled hidden states plus
            the attention mask.
        pooling_types: Pooling names given to ``Pooler``. ``None`` becomes
            ``["mean"]``. If ``"parti"`` is present (and neither
            ``matrix_embed`` nor ``tokenwise`` is set) the backbone is asked
            for attentions.
        task_type: Stored on the config; not read by the code in this module.
        num_labels: Stored on the config; not read by the code in this module.
        ppi: If true (and neither ``matrix_embed`` nor ``tokenwise`` is set),
            the input is split into two segments that are pooled separately
            and concatenated.
        add_token_ids: If true, ``token_type_ids`` are forwarded to a
            ``"transformer"`` probe when available.
        sep_token_id: Separator token id used to split segments when
            ``token_type_ids`` cannot be used.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "packaged_probe"

    def __init__(
        self,
        base_model_name: str = "",
        probe_type: str = "linear",
        probe_config: Optional[Dict[str, Any]] = None,
        tokenwise: bool = False,
        matrix_embed: bool = False,
        pooling_types: Optional[list[str]] = None,
        task_type: str = "singlelabel",
        num_labels: int = 2,
        ppi: bool = False,
        add_token_ids: bool = False,
        sep_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.probe_type = probe_type
        # Fresh containers per instance; never share a mutable default.
        self.probe_config = {} if probe_config is None else probe_config
        self.tokenwise = tokenwise
        self.matrix_embed = matrix_embed
        self.pooling_types = ["mean"] if pooling_types is None else pooling_types
        self.task_type = task_type
        self.num_labels = num_labels
        self.ppi = ppi
        self.add_token_ids = add_token_ids
        self.sep_token_id = sep_token_id


class PackagedProbeModel(PreTrainedModel):
    """Run a backbone, build probe inputs from its output, and apply a probe.

    The backbone and probe can be injected; otherwise they are constructed
    from the config (``_load_base_model`` / ``_load_probe``).
    """

    config_class = PackagedProbeConfig
    base_model_prefix = "backbone"
    all_tied_weights_keys = {}

    def __init__(
        self,
        config: PackagedProbeConfig,
        base_model: Optional[nn.Module] = None,
        probe: Optional[nn.Module] = None,
    ):
        """Build the model.

        Args:
            config: Packaged probe configuration.
            base_model: Backbone to use; loaded from ``config`` when ``None``.
            probe: Probe to use; rebuilt from ``config`` when ``None``.
        """
        super().__init__(config)
        self.config = config
        self.backbone = self._load_base_model() if base_model is None else base_model
        self.probe = self._load_probe() if probe is None else probe
        self.pooler = Pooler(self.config.pooling_types)

    def _load_base_model(self) -> nn.Module:
        """Load the backbone named by ``config.base_model_name`` in eval mode."""
        if self.config.base_model_name in all_presets_with_paths:
            model_path = all_presets_with_paths[self.config.base_model_name]
        else:
            # Not a known preset: treat the name itself as the model path.
            model_path = self.config.base_model_name
        # NOTE(review): trust_remote_code=True allows the checkpoint repository
        # to execute its own Python code on load; only point base_model_name
        # at sources you trust.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        return model

    def _load_probe(self) -> nn.Module:
        """Rebuild the probe from the probe settings stored on the config."""
        return rebuild_probe_from_saved_config(
            probe_type=self.config.probe_type,
            tokenwise=self.config.tokenwise,
            probe_config=self.config.probe_config,
        )

    @staticmethod
    def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
        """Return the hidden-state tensor from a backbone output.

        Accepts a tuple (first element is used), an object with a
        ``last_hidden_state`` attribute, or a bare tensor.

        Raises:
            ValueError: If the output is none of the supported formats.
        """
        if isinstance(backbone_output, tuple):
            return backbone_output[0]
        if hasattr(backbone_output, "last_hidden_state"):
            return backbone_output.last_hidden_state
        if isinstance(backbone_output, torch.Tensor):
            return backbone_output
        raise ValueError("Unsupported backbone output format for packaged probe model")

    @staticmethod
    def _extract_attentions(backbone_output: Any) -> Optional[torch.Tensor]:
        """Return ``backbone_output.attentions`` if that attribute exists, else ``None``."""
        if hasattr(backbone_output, "attentions"):
            return backbone_output.attentions
        return None

    def _build_ppi_segment_masks(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build per-token masks selecting segment A and segment B.

        If ``token_type_ids`` is given and contains at least one ``1``,
        segment A is type ``0`` and segment B is type ``1``. Otherwise the
        split falls back to ``config.sep_token_id``:

        * two or more separators: A is everything up to and including the
          first separator, B is what follows up to and including the second;
        * exactly one separator: A is everything up to and including it, B is
          the rest up to the last attended position;
        * no separator: the attended positions are split at their midpoint.

        In every case only positions where ``attention_mask == 1`` are kept.

        Args:
            input_ids: 2-D tensor of token ids (unpacked as ``batch, seq``).
            attention_mask: Mask with the same shape as ``input_ids``.
            token_type_ids: Optional segment ids with the same shape.

        Returns:
            ``(mask_a, mask_b)`` as ``torch.long`` tensors shaped like
            ``input_ids``.

        Raises:
            AssertionError: If ``sep_token_id`` is needed but unset, or if
                either segment is empty for any row.
        """
        if token_type_ids is not None and torch.any(token_type_ids == 1):
            mask_a = ((token_type_ids == 0) & (attention_mask == 1)).long()
            mask_b = ((token_type_ids == 1) & (attention_mask == 1)).long()
            _require(torch.all(mask_a.sum(dim=1) > 0), "PPI token_type_ids produced empty segment A")
            _require(torch.all(mask_b.sum(dim=1) > 0), "PPI token_type_ids produced empty segment B")
            return mask_a, mask_b

        _require(
            self.config.sep_token_id is not None,
            "sep_token_id is required for PPI fallback segmentation",
        )
        batch_size, seq_len = input_ids.shape
        attended = attention_mask == 1
        mask_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        mask_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)

        for batch_idx in range(batch_size):
            valid_positions = torch.where(attended[batch_idx])[0]
            if len(valid_positions) == 0:
                # Nothing attended in this row; the emptiness check below reports it.
                continue
            sep_positions = torch.where(
                (input_ids[batch_idx] == self.config.sep_token_id) & attended[batch_idx]
            )[0]

            if len(sep_positions) >= 2:
                first_sep = int(sep_positions[0].item())
                second_sep = int(sep_positions[1].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1:second_sep + 1] = 1
            elif len(sep_positions) == 1:
                first_sep = int(sep_positions[0].item())
                last_valid = int(valid_positions[-1].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1:last_valid + 1] = 1
            else:
                midpoint = len(valid_positions) // 2
                mask_a[batch_idx, valid_positions[:midpoint]] = 1
                mask_b[batch_idx, valid_positions[midpoint:]] = 1

        # The slice assignments above can cover unattended positions (e.g.
        # left padding before the first separator). Drop them so this path
        # agrees with the token_type_ids path, which also honours the mask.
        attended_long = attended.long()
        mask_a = mask_a * attended_long
        mask_b = mask_b * attended_long

        _require(torch.all(mask_a.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment A")
        _require(torch.all(mask_b.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment B")
        return mask_a, mask_b

    def _build_probe_inputs(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
        attentions: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Turn backbone hidden states into the embeddings fed to the probe.

        Returns:
            ``(embeddings, mask)`` where ``mask`` is ``attention_mask`` when
            unpooled hidden states are returned (``matrix_embed`` or
            ``tokenwise``) and ``None`` when the embeddings were pooled.
        """
        unpooled = self.config.matrix_embed or self.config.tokenwise

        if self.config.ppi and not unpooled:
            # Pool each segment on its own, then join along the feature axis.
            mask_a, mask_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
            vec_a = self.pooler(hidden_states, attention_mask=mask_a, attentions=attentions)
            vec_b = self.pooler(hidden_states, attention_mask=mask_b, attentions=attentions)
            return torch.cat((vec_a, vec_b), dim=-1), None

        if unpooled:
            return hidden_states, attention_mask

        pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
        return pooled, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput | TokenClassifierOutput:
        """Run backbone, build probe inputs, and return the probe's output.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Optional mask; defaults to all ones.
            token_type_ids: Optional segment ids. Used for PPI segmentation
                and, with ``config.add_token_ids``, passed to a
                ``"transformer"`` probe. Not passed to the backbone.
            labels: Optional labels forwarded to the probe.

        Raises:
            AssertionError: If ``"parti"`` pooling is requested but the
                backbone output has no attentions.
            ValueError: If ``config.probe_type`` is not supported.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # Attentions are only requested when "parti" pooling will actually run.
        requires_attentions = (
            "parti" in self.config.pooling_types
            and (not self.config.matrix_embed)
            and (not self.config.tokenwise)
        )
        backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
        if requires_attentions:
            backbone_kwargs["output_attentions"] = True

        backbone_output = self.backbone(**backbone_kwargs)
        hidden_states = self._extract_hidden_states(backbone_output)
        attentions = self._extract_attentions(backbone_output)
        if requires_attentions:
            _require(attentions is not None, "parti pooling requires base model attentions")

        probe_embeddings, probe_attention_mask = self._build_probe_inputs(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            attentions=attentions,
        )

        probe_type = self.config.probe_type
        if probe_type == "linear":
            return self.probe(embeddings=probe_embeddings, labels=labels)

        if probe_type == "transformer":
            forward_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
            if probe_attention_mask is not None:
                forward_kwargs["attention_mask"] = probe_attention_mask
            # token_type_ids are only passed with unpooled inputs (mask present).
            if self.config.add_token_ids and token_type_ids is not None and probe_attention_mask is not None:
                forward_kwargs["token_type_ids"] = token_type_ids
            return self.probe(**forward_kwargs)

        if probe_type in ["retrievalnet", "lyra"]:
            return self.probe(embeddings=probe_embeddings, attention_mask=probe_attention_mask, labels=labels)

        raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")