nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
import copy
import os
import torch
from huggingface_hub import HfApi, login
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig
from esm_plusplus.load_official import load_official_model
from esm_plusplus.modeling_esm_plusplus import ESMplusplusForMaskedLM
from weight_parity_utils import assert_state_dict_equal, assert_model_parameters_fp32
MODEL_DICT = {
"Synthyra/ESMplusplus_small": "esmc-300",
"Synthyra/ESMplusplus_large": "esmc-600",
}
def _resolve_repo_items(repo_ids: list[str] | None) -> list[tuple[str, str]]:
if repo_ids is None or len(repo_ids) == 0:
return list(MODEL_DICT.items())
selected_items: list[tuple[str, str]] = []
for repo_id in repo_ids:
assert repo_id in MODEL_DICT, (
f"Unknown repo_id {repo_id}. "
f"Valid options: {sorted(MODEL_DICT.keys())}"
)
selected_items.append((repo_id, MODEL_DICT[repo_id]))
return selected_items
if __name__ == "__main__":
# py -m esm_plusplus.get_esmc_weights
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--hf_token", type=str, default=None)
parser.add_argument("--repo_ids", nargs="*", type=str, default=None)
parser.add_argument("--dry_run", action="store_true")
parser.add_argument("--skip-weights", action="store_true")
args = parser.parse_args()
api = HfApi()
if args.hf_token is not None:
assert len(args.hf_token) > 0, "--hf_token cannot be empty."
login(token=args.hf_token)
script_root = os.path.dirname(os.path.abspath(__file__))
for repo_id, esmc_model_key in _resolve_repo_items(args.repo_ids):
official_model, tokenizer = load_official_model(esmc_model_key, device=torch.device("cpu"), dtype=torch.float32)
# load_official_model returns a wrapper, access the underlying model via .model
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config.auto_map = {
"AutoConfig": "modeling_esm_plusplus.ESMplusplusConfig",
"AutoModel": "modeling_esm_plusplus.ESMplusplusModel",
"AutoModelForMaskedLM": "modeling_esm_plusplus.ESMplusplusForMaskedLM",
"AutoModelForSequenceClassification": "modeling_esm_plusplus.ESMplusplusForSequenceClassification",
"AutoModelForTokenClassification": "modeling_esm_plusplus.ESMplusplusForTokenClassification",
}
config.tie_word_embeddings = False
if args.skip_weights:
if args.dry_run:
print(f"[skip-weights][dry-run] validated config+tokenizer parity for {repo_id}")
continue
config.push_to_hub(repo_id)
tokenizer.push_to_hub(repo_id)
print(f"[skip-weights] uploaded config+tokenizer for {repo_id}")
continue
model = ESMplusplusForMaskedLM(config=config).eval().cpu().to(torch.float32)
load_result = model.load_state_dict(official_model.model.state_dict(), strict=True)
# Manually load sequence head to prevent weight tying issues
model.sequence_head[0].weight = copy.deepcopy(official_model.model.sequence_head[0].weight)
model.sequence_head[0].bias = copy.deepcopy(official_model.model.sequence_head[0].bias)
model.sequence_head[2].weight = copy.deepcopy(official_model.model.sequence_head[2].weight)
model.sequence_head[2].bias = copy.deepcopy(official_model.model.sequence_head[2].bias)
model.sequence_head[3].weight = copy.deepcopy(official_model.model.sequence_head[3].weight)
model.sequence_head[3].bias = copy.deepcopy(official_model.model.sequence_head[3].bias)
assert_model_parameters_fp32(
model=model,
model_name=f"mapped ESM++ model ({esmc_model_key})",
)
assert_model_parameters_fp32(
model=official_model.model,
model_name=f"official ESM++ model ({esmc_model_key})",
)
assert_state_dict_equal(
reference_state_dict=official_model.model.state_dict(),
candidate_state_dict=model.state_dict(),
context=f"ESMC/ESM++ weight parity ({esmc_model_key})",
)
if args.dry_run:
print(f"[dry_run] validated ESM++ parity for {repo_id} <- {esmc_model_key}")
continue
tokenizer.push_to_hub(repo_id)
model.push_to_hub(repo_id)
api.upload_file(
path_or_fileobj="esm_plusplus/modeling_esm_plusplus.py",
path_in_repo="modeling_esm_plusplus.py",
repo_id=repo_id,
repo_type="model",
)
downloaded_model = AutoModelForMaskedLM.from_pretrained(
repo_id,
dtype=torch.float32,
device_map="cpu",
force_download=True,
trust_remote_code=True,
)
assert_state_dict_equal(
reference_state_dict=official_model.model.state_dict(),
candidate_state_dict=downloaded_model.state_dict(),
context=f"ESMC/ESM++ weight parity post-download ({repo_id})",
)