import gc
import shutil
import tempfile
from pathlib import Path

import torch
from transformers import AutoModel, BertConfig, BertModel, BertTokenizerFast
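# Support running this file both as a standalone script (flat imports) and as
# part of the installed package (relative imports).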
try:
    from probes.linear_probe import LinearProbe, LinearProbeConfig
    from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig
except ImportError:
    from ..probes.linear_probe import LinearProbe, LinearProbeConfig
    from ..probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from ..probes.transformer_probe import TransformerForSequenceClassification, TransformerProbeConfig


def _copy_runtime_code(save_dir: Path) -> None:
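    """Copy the protify source tree into ``save_dir`` so that
    ``AutoModel.from_pretrained(..., trust_remote_code=True)`` can import the
    probe classes from the saved checkpoint directory."""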
    # Assumes this test file sits three directory levels below the repo root,
    # e.g. <repo>/src/protify/tests/.
    repo_root = Path(__file__).resolve().parents[3]
    src_package_dir = repo_root / "src" / "protify"
    dst_package_dir = save_dir / "protify"
    for src_file in src_package_dir.rglob("*.py"):
        relative_path = src_file.relative_to(src_package_dir)
        dst_file = dst_package_dir / relative_path
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)
    # The auto_map entries point at "packaged_probe_model.<Class>", so that
    # module must also sit at the top level of the checkpoint directory.
    packaged_model_file = repo_root / "src" / "protify" / "probes" / "packaged_probe_model.py"
    shutil.copy2(packaged_model_file, save_dir / "packaged_probe_model.py")


def _create_tiny_backbone(backbone_dir: Path) -> tuple[BertModel, BertTokenizerFast]:
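    """Build and save a tiny random-weight BERT backbone plus a matching
    tokenizer, small enough for fast smoke tests."""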
    vocab_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "A", "B", "C", "D"]
    vocab_path = backbone_dir / "vocab.txt"
    vocab_path.write_text("\n".join(vocab_tokens), encoding="utf-8")
    tokenizer = BertTokenizerFast(vocab_file=str(vocab_path), do_lower_case=False)
    config = BertConfig(
        vocab_size=len(vocab_tokens),
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=32,
    )
    model = BertModel(config).eval()
    model.save_pretrained(str(backbone_dir))
    tokenizer.save_pretrained(str(backbone_dir))
    return model, tokenizer


def _save_and_load_with_automodel(
    packaged_model: PackagedProbeModel,
    tokenizer: BertTokenizerFast,
    model_dir: Path,
) -> AutoModel:
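    """Save the packaged model with an ``auto_map``, then reload it through
    ``AutoModel`` to exercise the trust-remote-code round trip."""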
    # auto_map tells AutoConfig/AutoModel which classes in the checkpoint's
    # packaged_probe_model.py to instantiate when trust_remote_code=True.
    packaged_model.config.auto_map = {
        "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
        "AutoModel": "packaged_probe_model.PackagedProbeModel",
    }
    packaged_model.config.architectures = ["PackagedProbeModel"]
    packaged_model.save_pretrained(str(model_dir), safe_serialization=True)
    tokenizer.save_pretrained(str(model_dir))
    _copy_runtime_code(model_dir)
    return AutoModel.from_pretrained(str(model_dir), trust_remote_code=True)


def test_linear_packaged_roundtrip() -> None:
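    """Package a linear probe on the tiny backbone, save it, reload it via
    AutoModel, and check the logits shape on a two-sequence batch."""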
    with tempfile.TemporaryDirectory(prefix="protify_linear_packaged_test_", ignore_cleanup_errors=True) as temp_dir:
        temp_path = Path(temp_dir)
        backbone_dir = temp_path / "backbone"
        model_dir = temp_path / "linear_packaged_model"
        backbone_dir.mkdir(parents=True, exist_ok=True)
        model_dir.mkdir(parents=True, exist_ok=True)
        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        probe_config = LinearProbeConfig(
            input_size=16,
            hidden_size=32,
            dropout=0.1,
            num_labels=3,
            n_layers=1,
            task_type="singlelabel",
        )
        probe = LinearProbe(probe_config).eval()
        packaged_config = PackagedProbeConfig(
            base_model_name=str(backbone_dir),
            probe_type="linear",
            probe_config=probe.config.to_dict(),
            tokenwise=False,
            matrix_embed=False,
            pooling_types=["mean"],
            task_type="singlelabel",
            num_labels=3,
            ppi=False,
            add_token_ids=False,
            sep_token_id=tokenizer.sep_token_id,
        )
        packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval()
        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        batch = tokenizer(["A B C A", "B C D A"], padding="longest", return_tensors="pt")
        outputs = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        assert outputs.logits.shape == (2, 3), f"Unexpected linear packaged logits shape: {outputs.logits.shape}"
        del loaded_model
        gc.collect()


def test_transformer_packaged_roundtrip() -> None:
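    """Round-trip a transformer probe (matrix_embed=True) through save and
    AutoModel reload, checking the logits shape."""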
    with tempfile.TemporaryDirectory(prefix="protify_transformer_packaged_test_", ignore_cleanup_errors=True) as temp_dir:
        temp_path = Path(temp_dir)
        backbone_dir = temp_path / "backbone"
        model_dir = temp_path / "transformer_packaged_model"
        backbone_dir.mkdir(parents=True, exist_ok=True)
        model_dir.mkdir(parents=True, exist_ok=True)
        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        probe_config = TransformerProbeConfig(
            input_size=16,
            hidden_size=16,
            classifier_size=24,
            transformer_dropout=0.1,
            classifier_dropout=0.1,
            num_labels=2,
            n_layers=1,
            token_attention=False,
            n_heads=2,
            task_type="singlelabel",
            rotary=False,
            pre_ln=True,
            probe_pooling_types=["mean"],
            use_bias=False,
            add_token_ids=False,
        )
        probe = TransformerForSequenceClassification(probe_config).eval()
        packaged_config = PackagedProbeConfig(
            base_model_name=str(backbone_dir),
            probe_type="transformer",
            probe_config=probe.config.to_dict(),
            tokenwise=False,
            matrix_embed=True,
            pooling_types=["mean"],
            task_type="singlelabel",
            num_labels=2,
            ppi=False,
            add_token_ids=False,
            sep_token_id=tokenizer.sep_token_id,
        )
        packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval()
        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        batch = tokenizer(["A B C D", "D C B A"], padding="longest", return_tensors="pt")
        outputs = loaded_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        assert outputs.logits.shape == (2, 2), f"Unexpected transformer packaged logits shape: {outputs.logits.shape}"
        del loaded_model
        gc.collect()


def test_ppi_packaged_inference_with_and_without_token_type_ids() -> None:
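    """Package a linear PPI probe over sequence pairs and check that inference
    works both with and without explicit token_type_ids."""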
    with tempfile.TemporaryDirectory(prefix="protify_ppi_packaged_test_", ignore_cleanup_errors=True) as temp_dir:
        temp_path = Path(temp_dir)
        backbone_dir = temp_path / "backbone"
        model_dir = temp_path / "ppi_packaged_model"
        backbone_dir.mkdir(parents=True, exist_ok=True)
        model_dir.mkdir(parents=True, exist_ok=True)
        backbone, tokenizer = _create_tiny_backbone(backbone_dir)
        probe_config = LinearProbeConfig(
            # input_size is twice the backbone hidden size (16), presumably
            # because the PPI path concatenates the two pooled embeddings.
            input_size=32,
            hidden_size=24,
            dropout=0.1,
            num_labels=2,
            n_layers=1,
            task_type="singlelabel",
        )
        probe = LinearProbe(probe_config).eval()
        packaged_config = PackagedProbeConfig(
            base_model_name=str(backbone_dir),
            probe_type="linear",
            probe_config=probe.config.to_dict(),
            tokenwise=False,
            matrix_embed=False,
            pooling_types=["mean"],
            task_type="singlelabel",
            num_labels=2,
            ppi=True,
            add_token_ids=False,
            sep_token_id=tokenizer.sep_token_id,
        )
        packaged_model = PackagedProbeModel(config=packaged_config, base_model=backbone, probe=probe).eval()
        loaded_model = _save_and_load_with_automodel(packaged_model, tokenizer, model_dir)
        pair_batch = tokenizer(
            ["A B C", "B C D"],
            ["D C B", "A C B"],
            padding="longest",
            return_tensors="pt",
        )
        outputs_with_token_types = loaded_model(
            input_ids=pair_batch["input_ids"],
            attention_mask=pair_batch["attention_mask"],
            token_type_ids=pair_batch["token_type_ids"],
        )
        assert outputs_with_token_types.logits.shape == (2, 2), "PPI logits shape mismatch with token_type_ids"
        outputs_without_token_types = loaded_model(
            input_ids=pair_batch["input_ids"],
            attention_mask=pair_batch["attention_mask"],
        )
        assert outputs_without_token_types.logits.shape == (2, 2), "PPI logits shape mismatch without token_type_ids"
        del loaded_model
        gc.collect()


def main() -> None:
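    """Run all packaged-probe smoke tests with a fixed seed."""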
    torch.manual_seed(0)
    test_linear_packaged_roundtrip()
    test_transformer_packaged_roundtrip()
    test_ppi_packaged_inference_with_and_without_token_type_ids()
    print("Packaged probe model smoke tests passed.")


if __name__ == "__main__":
    main()