import glob
import json
import os
from typing import Optional

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig


class TargetEmbeddingsAndHead(nn.Module):
    """
    Efficiently loads only the embedding layer and lm_head from a pretrained model.
    Avoids loading the full model into memory.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        embed_key: str = "model.embed_tokens.weight",
        lm_head_key: str = "lm_head.weight",
        cache_dir: Optional[str] = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        trust_remote_code: bool = False,
    ) -> "TargetEmbeddingsAndHead":
        # 1. Load the config (supplies vocab_size, hidden_size, pad_token_id).
        config = AutoConfig.from_pretrained(
            model_path, cache_dir=cache_dir, trust_remote_code=trust_remote_code
        )
        instance = cls(config)

        # 2. Resolve the model path: if it is not a local directory, treat it
        #    as a Hugging Face Hub repo ID and download a snapshot.
        local_model_path = model_path
        if not os.path.exists(local_model_path):
            try:
                local_model_path = snapshot_download(
                    repo_id=model_path, cache_dir=cache_dir
                )
            except Exception:
                # Neither a local path nor a downloadable repo ID;
                # _load_weights will raise a clearer FileNotFoundError below.
                pass

        # 3. Load only the required weights.
        instance._load_weights(local_model_path, embed_key, lm_head_key)

        # 4. Move to device and freeze.
        instance.to(device=device, dtype=dtype)
        instance.eval()
        instance.requires_grad_(False)
        return instance

    def _load_weights(self, model_path: str, embed_key: str, lm_head_key: str):
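        """Populate embed_tokens / lm_head from a local checkpoint directory.

        Reads only the shard(s) (or the single checkpoint file) that contain
        the two requested keys.
        """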
        # A *.index.json file indicates a sharded checkpoint.
        index_files = glob.glob(os.path.join(model_path, "*.index.json"))
        weight_map = {}
        if index_files:
            # Sharded checkpoint: the index maps each tensor name to its shard file.
            with open(index_files[0], "r") as f:
                index = json.load(f)
            weight_map = index.get("weight_map", {})

            # Find which shard files contain our two keys.
            files_to_load = {}
            if embed_key in weight_map:
                files_to_load[embed_key] = weight_map[embed_key]
            else:
                # Some checkpoints prefix the embedding key differently
                # (e.g. custom architectures); surface a few real keys so the
                # caller can pass the right `embed_key`.
                print(
                    f"Warning: {embed_key} not found in weight_map. "
                    f"Keys available: {list(weight_map.keys())[:5]}..."
                )
            if lm_head_key in weight_map:
                files_to_load[lm_head_key] = weight_map[lm_head_key]

            # Load only the specific shard files we need.
            for key, filename in files_to_load.items():
                file_path = os.path.join(model_path, filename)
                self._load_key_from_file(file_path, key)
        else:
            # Non-sharded checkpoint: look for a single .safetensors or .bin file.
            safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
            bin_files = glob.glob(os.path.join(model_path, "*.bin"))
            target_file = None
            if safetensors_files:
                target_file = safetensors_files[0]
            elif bin_files:
                target_file = bin_files[0]

            if target_file:
                self._load_key_from_file(target_file, embed_key)
                self._load_key_from_file(target_file, lm_head_key)
            else:
                raise FileNotFoundError(f"No checkpoint file found in {model_path}")

    def _load_key_from_file(self, file_path: str, key: str):
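        """Copy the tensor named `key` out of one checkpoint file, if present."""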
        tensor = None
        if file_path.endswith(".safetensors"):
            # safetensors can read a single tensor without loading the rest
            # of the file.
            with safe_open(file_path, framework="pt", device="cpu") as f:
                if key in f.keys():
                    tensor = f.get_tensor(key)
        else:
            # torch.load must deserialize the whole state dict; less
            # efficient, but it works for .bin checkpoints.
            state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
            if key in state_dict:
                tensor = state_dict[key]
            del state_dict  # Free the rest of the state dict immediately.

        if tensor is not None:
            if key.endswith("embed_tokens.weight"):
                self.embed_tokens.weight.data.copy_(tensor)
                print(f"Loaded embedding weights from {file_path}")
            elif key.endswith("lm_head.weight"):
                self.lm_head.weight.data.copy_(tensor)
                print(f"Loaded lm_head weights from {file_path}")
        else:
            print(f"Warning: Key {key} not found in {file_path}")