"""Export trained Protify probes as self-contained, Hub-loadable model repos."""

import shutil
import tempfile
from pathlib import Path
from typing import Optional

from huggingface_hub import HfApi
from torch import nn

# Support both flat (script-style) and package-relative import contexts.
try:
    from base_models.supported_models import all_presets_with_paths
    from probes.hybrid_probe import HybridProbe
    from probes.packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from utils import print_message
except ImportError:
    from ..base_models.supported_models import all_presets_with_paths
    from .hybrid_probe import HybridProbe
    from .packaged_probe_model import PackagedProbeConfig, PackagedProbeModel
    from ..utils import print_message


def _infer_probe_type(probe_model: nn.Module) -> str:
    """Map a probe class name to the short type tag stored in the packaged config."""
    probe_class_name = probe_model.__class__.__name__
    if probe_class_name == "LinearProbe":
        return "linear"
    if probe_class_name in ["TransformerForSequenceClassification", "TransformerForTokenClassification"]:
        return "transformer"
    if probe_class_name in ["RetrievalNetForSequenceClassification", "RetrievalNetForTokenClassification"]:
        return "retrievalnet"
    if probe_class_name in ["LyraForSequenceClassification", "LyraForTokenClassification"]:
        return "lyra"
    raise ValueError(f"Unsupported probe class for packaged export: {probe_class_name}")


def _is_supported_base_model(source_model_name: str) -> bool:
    """Packaged export only supports real pretrained presets, not synthetic baselines."""
    if source_model_name not in all_presets_with_paths:
        return False
    model_name_l = source_model_name.lower()
    # Random, one-hot, and vec2vec "models" are baselines without Hub weights.
    return not any(tag in model_name_l for tag in ("random", "onehot", "vec2vec"))


def _extract_sep_token_id(tokenizer) -> Optional[int]:
    """Return a separator token id, falling back to EOS, or None if neither is set."""
    # Some Protify tokenizers wrap the underlying HF tokenizer in a `.tokenizer` attribute.
    tokenizer_backend = getattr(tokenizer, "tokenizer", tokenizer)
    if tokenizer_backend.sep_token_id is not None:
        return int(tokenizer_backend.sep_token_id)
    if tokenizer_backend.eos_token_id is not None:
        return int(tokenizer_backend.eos_token_id)
    return None


def _copy_runtime_code(export_dir: Path) -> None:
    """Copy the protify source tree into the export so the repo is self-contained."""
    repo_root = Path(__file__).resolve().parents[3]
    src_package_dir = repo_root / "src" / "protify"
    dst_package_dir = export_dir / "protify"
    for src_file in src_package_dir.rglob("*.py"):
        relative_path = src_file.relative_to(src_package_dir)
        dst_file = dst_package_dir / relative_path
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_file, dst_file)
    # The auto_map entries point at a top-level module, so also copy it to the repo root.
    packaged_model_file = Path(__file__).with_name("packaged_probe_model.py")
    shutil.copy2(packaged_model_file, export_dir / "packaged_probe_model.py")


def _build_packaged_model(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    ppi: bool,
) -> PackagedProbeModel:
    """Bundle the (optional) base model and probe into a single PackagedProbeModel."""
    if isinstance(trained_model, HybridProbe):
        # Hybrid probes already carry their base model; split it out for packaging.
        base_model = trained_model.model
        probe_model = trained_model.probe
    else:
        base_model = None
        probe_model = trained_model

    probe_type = _infer_probe_type(probe_model)
    probe_config_dict = probe_model.config.to_dict()
    sep_token_id = _extract_sep_token_id(tokenizer)
    packaged_config = PackagedProbeConfig(
        base_model_name=source_model_name,
        probe_type=probe_type,
        probe_config=probe_config_dict,
        tokenwise=probe_args.tokenwise,
        matrix_embed=embedding_args.matrix_embed,
        pooling_types=embedding_args.pooling_types,
        task_type=probe_args.task_type,
        num_labels=probe_args.num_labels,
        ppi=ppi,
        add_token_ids=probe_args.add_token_ids,
        sep_token_id=sep_token_id,
    )
    packaged_model = PackagedProbeModel(config=packaged_config, base_model=base_model, probe=probe_model)
    # Move to CPU so the serialized weights are device-agnostic.
    return packaged_model.cpu()


def export_packaged_model_to_hub(
    trained_model: nn.Module,
    source_model_name: str,
    probe_args,
    embedding_args,
    tokenizer,
    repo_id: str,
    model_card: str,
    ppi: bool = False,
    private: bool = True,
    hf_token: Optional[str] = None,
) -> tuple[bool, str]:
    """Package a trained probe and upload it, with tokenizer and runtime code, to the Hub.

    Returns a (success, message) tuple instead of raising on unsupported base models.
    """
    if not _is_supported_base_model(source_model_name):
        return False, f"Packaged export is not supported for base model: {source_model_name}"

    packaged_model = _build_packaged_model(
        trained_model=trained_model,
        source_model_name=source_model_name,
        probe_args=probe_args,
        embedding_args=embedding_args,
        tokenizer=tokenizer,
        ppi=ppi,
    )

    with tempfile.TemporaryDirectory(prefix="protify_packaged_model_") as temp_dir:
        export_dir = Path(temp_dir)
        # Register the custom classes so AutoConfig/AutoModel can load the repo
        # with trust_remote_code=True.
        packaged_model.config.auto_map = {
            "AutoConfig": "packaged_probe_model.PackagedProbeConfig",
            "AutoModel": "packaged_probe_model.PackagedProbeModel",
        }
        packaged_model.config.architectures = ["PackagedProbeModel"]
        packaged_model.save_pretrained(str(export_dir), safe_serialization=True)
        tokenizer.save_pretrained(str(export_dir))
        _copy_runtime_code(export_dir)
        readme_path = export_dir / "README.md"
        readme_path.write_text(model_card, encoding="utf-8")

        # token=None falls back to locally cached Hugging Face credentials.
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
        api.upload_folder(
            repo_id=repo_id,
            repo_type="model",
            folder_path=str(export_dir),
            path_in_repo="",
        )

    print_message(f"Packaged model and tokenizer uploaded to Hugging Face Hub: {repo_id}")
    return True, f"Uploaded packaged model to {repo_id}"
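

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module API). All names
# below are hypothetical placeholders: `trained_probe`, `probe_args`,
# `embedding_args`, and `tokenizer` come from a finished Protify training run,
# and "esm2-650m" stands in for whichever preset key from
# `all_presets_with_paths` was actually used.
#
#   ok, message = export_packaged_model_to_hub(
#       trained_model=trained_probe,       # HybridProbe or bare probe module
#       source_model_name="esm2-650m",     # assumed preset name
#       probe_args=probe_args,
#       embedding_args=embedding_args,
#       tokenizer=tokenizer,
#       repo_id="your-username/your-probe",
#       model_card="# My packaged probe\n",
#       private=True,
#       hf_token=None,                     # None -> cached HF credentials
#   )
#
# Because the export sets `auto_map` and copies the runtime code into the
# repo, the uploaded model should be loadable through transformers' standard
# remote-code mechanism:
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("your-username/your-probe", trust_remote_code=True)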