"""Export trained probes as self-contained Hugging Face Hub models.

A packaged export bundles the probe weights, the tokenizer, the protify
runtime code, and a README model card into a single model repository.
"""

import shutil
import tempfile
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi
from torch import nn

# Support both flat execution (package root on sys.path) and package-relative imports.
try:
    from base_models.supported_models import all_presets_with_paths
    from probes.hybrid_probe import HybridProbe
    from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from utils import print_message
except ImportError:
    from ..base_models.supported_models import all_presets_with_paths
    from .hybrid_probe import HybridProbe
    from .packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from ..utils import print_message


def _infer_probe_type(probe_model: nn.Module) -> str:
    """Map a probe's class name to the probe-type string stored in the packaged config."""
    probe_class_name = probe_model.__class__.__name__
    if probe_class_name == "LinearProbe":
        return "linear"
    if probe_class_name in ("TransformerForSequenceClassification", "TransformerForTokenClassification"):
        return "transformer"
    if probe_class_name in ("RetrievalNetForSequenceClassification", "RetrievalNetForTokenClassification"):
        return "retrievalnet"
    if probe_class_name in ("LyraForSequenceClassification", "LyraForTokenClassification"):
        return "lyra"
    raise ValueError(f"Unsupported probe class for packaged export: {probe_class_name}")


def _is_supported_base_model(source_model_name: str) -> bool:
    """Return True if the base model is a known preset that packaged export supports."""
    if source_model_name not in all_presets_with_paths:
        return False
    model_name_l = source_model_name.lower()
    # Synthetic embedding presets (random, one-hot, vec2vec) are excluded from packaged export.
    return not any(tag in model_name_l for tag in ("random", "onehot", "vec2vec"))


def _extract_sep_token_id(tokenizer) -> Optional[int]:
    """Return the separator token id, falling back to EOS; None if neither is set."""
    # Some wrappers expose the underlying tokenizer via a `.tokenizer` attribute.
    tokenizer_backend = getattr(tokenizer, "tokenizer", tokenizer)
    if tokenizer_backend.sep_token_id is not None:
        return int(tokenizer_backend.sep_token_id)
    if tokenizer_backend.eos_token_id is not None:
        return int(tokenizer_backend.eos_token_id)
    return None


def _copy_runtime_code(export_dir: Path) -> None:
    """Copy the protify package source and the packaged-model module into the export."""
    repo_root = Path(__file__).resolve().parents[3]
    src_package_dir = repo_root / "src" / "protify"
    dst_package_dir = export_dir / "protify"

    # Mirror every Python source file, preserving the package layout.
    for src_file in src_package_dir.rglob("*.py"):
        relative_path = src_file.relative_to(src_package_dir)
        dst_file = dst_package_dir / relative_path
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)

    # The remote-code entry point must sit at the repo root so `auto_map` can resolve it.
    packaged_model_file = Path(__file__).with_name("packaged_probe_model.py")
    shutil.copy2(packaged_model_file, export_dir / "packaged_probe_model.py")


def _build_packaged_model(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    ppi: bool,
) -> PackagedProbeModel:
    """Wrap a trained probe (and its base model, if hybrid) in a PackagedProbeModel."""
    if isinstance(trained_model, HybridProbe):
        base_model = trained_model.model
        probe_model = trained_model.probe
    else:
        # Standalone probe with no fused base model.
        base_model = None
        probe_model = trained_model

    probe_type = _infer_probe_type(probe_model)
    probe_config_dict = probe_model.config.to_dict()
    sep_token_id = _extract_sep_token_id(tokenizer)
    packaged_config = PackagedProbeConfig(
        base_model_name=source_model_name,
        probe_type=probe_type,
        probe_config=probe_config_dict,
        tokenwise=probe_args.tokenwise,
        matrix_embed=embedding_args.matrix_embed,
        pooling_types=embedding_args.pooling_types,
        task_type=probe_args.task_type,
        num_labels=probe_args.num_labels,
        ppi=ppi,
        add_token_ids=probe_args.add_token_ids,
        sep_token_id=sep_token_id,
    )
    packaged_model = PackagedProbeModel(config=packaged_config, base_model=base_model, probe=probe_model)
    # Move to CPU so the serialized checkpoint is device-agnostic.
    return packaged_model.cpu()


def export_packaged_model_to_hub(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    repo_id: str,
    model_card: str,
    ppi: bool = False,
    private: bool = True,
    hf_token: Optional[str] = None,
) -> tuple[bool, str]:
    """Package a trained probe and upload it, with tokenizer and runtime code, to the Hub.

    Returns a (success, message) tuple instead of raising on unsupported base models.
    """
    if not _is_supported_base_model(source_model_name):
        return False, f"Packaged export is not supported for base model: {source_model_name}"

    packaged_model = _build_packaged_model(
        trained_model=trained_model,
        source_model_name=source_model_name,
        probe_args=probe_args,
        embedding_args=embedding_args,
        tokenizer=tokenizer,
        ppi=ppi,
    )

    with tempfile.TemporaryDirectory(prefix="protify_packaged_model_") as temp_dir:
        export_dir = Path(temp_dir)

        # Register the custom classes so AutoConfig/AutoModel can resolve them
        # from the uploaded packaged_probe_model.py.
        packaged_model.config.auto_map = {
            "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
            "AutoModel": "packaged_probe_model.PackagedProbeModel",
        }
        packaged_model.config.architectures = ["PackagedProbeModel"]
        packaged_model.save_pretrained(str(export_dir), safe_serialization=True)
        tokenizer.save_pretrained(str(export_dir))
        _copy_runtime_code(export_dir)
        readme_path = export_dir / "README.md"
        readme_path.write_text(model_card, encoding="utf-8")

        # HfApi defaults to token=None, so a single constructor call covers both cases.
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
        api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=str(export_dir),
            path_in_repo="",
        )

    print_message(f"Packaged model and tokenizer uploaded to Hugging Face Hub: {repo_id}")
    return True, f"Uploaded packaged model to {repo_id}"
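
# Illustrative usage sketch (not executed; variable names are hypothetical
# placeholders from a typical training run, and "esm2-8m" is an assumed preset
# key in `all_presets_with_paths` -- substitute your own):
#
#     ok, message = export_packaged_model_to_hub(
#         trained_model=trained_model,          # probe or HybridProbe from training
#         source_model_name="esm2-8m",          # assumed preset name
#         probe_args=probe_args,
#         embedding_args=embedding_args,
#         tokenizer=tokenizer,
#         repo_id="your-username/my-probe",     # hypothetical repo id
#         model_card="# My probe\n\nTrained with Protify.",
#         private=True,
#     )
#     if not ok:
#         print_message(message)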
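
# Downstream loading sketch (assumes standard transformers dynamic-module
# behavior; untested here). Because the saved config sets `auto_map` and
# `packaged_probe_model.py` is uploaded next to the weights, the packaged
# probe should load with remote code enabled:
#
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("your-username/my-probe", trust_remote_code=True)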