"""Smoke tests: save packaged probe models and reload them through AutoModel.

Each test builds a tiny BERT backbone, wraps it with a probe inside
``PackagedProbeModel``, saves the result with remote-code mappings, reloads it
via ``AutoModel.from_pretrained(..., trust_remote_code=True)``, and checks the
logits shape of a small forward pass.
"""

import gc
import shutil
import tempfile
from pathlib import Path

import torch
from transformers import AutoModel, BertConfig, BertModel, BertTokenizerFast

try:
    from probes.linear_probe import LinearProbe, LinearProbeConfig
    from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig
except ImportError:
    from ..probes.linear_probe import LinearProbe, LinearProbeConfig
    from ..probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from ..probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig


def _copy_runtime_code(save_dir: Path) -> None:
    """Bundle the protify runtime sources alongside a saved model.

    Copies every ``*.py`` file under ``src/protify`` into ``save_dir/protify``,
    then places ``packaged_probe_model.py`` at the top level of ``save_dir`` so
    the flat module names in ``auto_map`` resolve under trust_remote_code.
    """
    repo_root = Path(__file__).resolve().parents[3]
    package_src = repo_root / "src" / "protify"
    package_dst = save_dir / "protify"
    for py_file in package_src.rglob("*.py"):
        destination = package_dst / py_file.relative_to(package_src)
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(py_file, destination)
    # auto_map references a top-level module name, so mirror the file at the root.
    probe_module = repo_root / "src" / "protify" / "probes" / "packaged_probe_model.py"
    shutil.copy2(probe_module, save_dir / "packaged_probe_model.py")


def _create_tiny_backbone(backbone_dir: Path) -> tuple[BertModel, BertTokenizerFast]:
    """Create and persist a minimal BERT backbone plus a matching tokenizer.

    The vocabulary holds only the special tokens and four letters, and the
    model dimensions are tiny, keeping the tests fast.
    """
    tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "A", "B", "C", "D"]
    vocab_file = backbone_dir / "vocab.txt"
    vocab_file.write_text("\n".join(tokens), encoding="utf-8")
    tokenizer = BertTokenizerFast(vocab_file=str(vocab_file), do_lower_case=False)
    backbone_config = BertConfig(
        vocab_size=len(tokens),
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=32,
    )
    backbone = BertModel(backbone_config).eval()
    backbone.save_pretrained(str(backbone_dir))
    tokenizer.save_pretrained(str(backbone_dir))
    return backbone, tokenizer


def _save_and_load_with_automodel(
    packaged_model: PackagedProbeModel,
    tokenizer: BertTokenizerFast,
    model_dir: Path,
) -> AutoModel:
    """Save a packaged model with remote-code mappings and reload it via AutoModel."""
    # Point AutoConfig/AutoModel at the custom classes shipped by _copy_runtime_code.
    packaged_model.config.auto_map = {
        "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
        "AutoModel": "packaged_probe_model.PackagedProbeModel",
    }
    packaged_model.config.architectures = ["PackagedProbeModel"]
    packaged_model.save_pretrained(str(model_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(model_dir))
    _copy_runtime_code(model_dir)
    return AutoModel.from_pretrained(str(model_dir), trust_remote_code=True)


def test_linear_packaged_roundtrip() -> None:
    """Round-trip a linear probe through save_pretrained and AutoModel loading."""
    with tempfile.TemporaryDirectory(
        prefix="protify_linear_packaged_test_", ignore_cleanup_errors=True
    ) as temp_dir:
        workdir = Path(temp_dir)
        backbone_dir = workdir / "backbone"
        model_dir = workdir / "linear_packaged_model"
        for directory in (backbone_dir, model_dir):
            directory.mkdir(parents=True, exist_ok=True)

        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        probe = LinearProbe(
            LinearProbeConfig(
                input_size=16,
                hidden_size=32,
                dropout=0.1,
                num_labels=3,
                n_layers=1,
                task_type="singlelabel",
            )
        ).eval()
        packaged_model = PackagedProbeModel(
            config=PackagedProbeConfig(
                base_model_name=str(backbone_dir),
                probe_type="linear",
                probe_config=probe.config.to_dict(),
                tokenwise=False,
                matrix_embed=False,
                pooling_types=["mean"],
                task_type="singlelabel",
                num_labels=3,
                ppi=False,
                add_token_ids=False,
                sep_token_id=tokenizer.sep_token_id,
            ),
            base_model=backbone,
            probe=probe,
        ).eval()

        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        encoded = tokenizer(["A B C A", "B C D A"], padding="longest", return_tensors="pt")
        outputs = loaded_model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
        assert outputs.logits.shape == (2, 3), f"Unexpected linear packaged logits shape: {outputs.logits.shape}"
        # Drop the loaded model before the temp dir is cleaned up.
        del loaded_model
        gc.collect()


def test_transformer_packaged_roundtrip() -> None:
    """Round-trip a transformer probe (matrix embeddings) through AutoModel loading."""
    with tempfile.TemporaryDirectory(
        prefix="protify_transformer_packaged_test_", ignore_cleanup_errors=True
    ) as temp_dir:
        workdir = Path(temp_dir)
        backbone_dir = workdir / "backbone"
        model_dir = workdir / "transformer_packaged_model"
        for directory in (backbone_dir, model_dir):
            directory.mkdir(parents=True, exist_ok=True)

        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        probe = TransformerForSequenceClassification(
            TransformerProbeConfig(
                input_size=16,
                hidden_size=16,
                classifier_size=24,
                transformer_dropout=0.1,
                classifier_dropout=0.1,
                num_labels=2,
                n_layers=1,
                token_attention=False,
                n_heads=2,
                task_type="singlelabel",
                rotary=False,
                pre_ln=True,
                probe_pooling_types=["mean"],
                use_bias=False,
                add_token_ids=False,
            )
        ).eval()
        packaged_model = PackagedProbeModel(
            config=PackagedProbeConfig(
                base_model_name=str(backbone_dir),
                probe_type="transformer",
                probe_config=probe.config.to_dict(),
                tokenwise=False,
                matrix_embed=True,
                pooling_types=["mean"],
                task_type="singlelabel",
                num_labels=2,
                ppi=False,
                add_token_ids=False,
                sep_token_id=tokenizer.sep_token_id,
            ),
            base_model=backbone,
            probe=probe,
        ).eval()

        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        encoded = tokenizer(["A B C D", "D C B A"], padding="longest", return_tensors="pt")
        outputs = loaded_model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])
        assert outputs.logits.shape == (2, 2), f"Unexpected transformer packaged logits shape: {outputs.logits.shape}"
        # Drop the loaded model before the temp dir is cleaned up.
        del loaded_model
        gc.collect()


def test_ppi_packaged_inference_with_and_without_token_type_ids() -> None:
    """Run a PPI (sequence-pair) packaged model both with and without token_type_ids."""
    with tempfile.TemporaryDirectory(
        prefix="protify_ppi_packaged_test_", ignore_cleanup_errors=True
    ) as temp_dir:
        workdir = Path(temp_dir)
        backbone_dir = workdir / "backbone"
        model_dir = workdir / "ppi_packaged_model"
        for directory in (backbone_dir, model_dir):
            directory.mkdir(parents=True, exist_ok=True)

        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        # PPI concatenates two pooled embeddings, so the probe input is doubled.
        probe = LinearProbe(
            LinearProbeConfig(
                input_size=32,
                hidden_size=24,
                dropout=0.1,
                num_labels=2,
                n_layers=1,
                task_type="singlelabel",
            )
        ).eval()
        packaged_model = PackagedProbeModel(
            config=PackagedProbeConfig(
                base_model_name=str(backbone_dir),
                probe_type="linear",
                probe_config=probe.config.to_dict(),
                tokenwise=False,
                matrix_embed=False,
                pooling_types=["mean"],
                task_type="singlelabel",
                num_labels=2,
                ppi=True,
                add_token_ids=False,
                sep_token_id=tokenizer.sep_token_id,
            ),
            base_model=backbone,
            probe=probe,
        ).eval()

        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        pair_batch = tokenizer(
            ["A B C", "B C D"],
            ["D C B", "A C B"],
            padding="longest",
            return_tensors="pt",
        )
        outputs_with_token_types = loaded_model(
            input_ids=pair_batch["input_ids"],
            attention_mask=pair_batch["attention_mask"],
            token_type_ids=pair_batch["token_type_ids"],
        )
        assert outputs_with_token_types.logits.shape == (2, 2), "PPI logits shape mismatch with token_type_ids"
        outputs_without_token_types = loaded_model(
            input_ids=pair_batch["input_ids"],
            attention_mask=pair_batch["attention_mask"],
        )
        assert outputs_without_token_types.logits.shape == (2, 2), "PPI logits shape mismatch without token_type_ids"
        # Drop the loaded model before the temp dir is cleaned up.
        del loaded_model
        gc.collect()


def main() -> None:
    """Run every packaged-probe smoke test with a fixed random seed."""
    torch.manual_seed(0)
    test_linear_packaged_roundtrip()
    test_transformer_packaged_roundtrip()
    test_ppi_packaged_inference_with_and_without_token_type_ids()
    print("Packaged probe model smoke tests passed.")


if __name__ == "__main__":
    main()