| | """ |
| | Helion-OSC Sharded Model Loader |
| | Efficiently loads 116 safetensors shards (2.8GB each) |
| | """ |
| |
|
| | import torch |
| | import json |
| | import os |
| | from pathlib import Path |
| | from typing import Dict, Optional, List |
| | import logging |
| | from tqdm import tqdm |
| | from safetensors.torch import load_file |
| | from transformers import AutoConfig, AutoTokenizer |
| | import psutil |
| |
|
# Module-level logger; basicConfig is a no-op if the embedding application
# has already configured logging handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
class ShardedModelLoader:
    """
    Loader for sharded safetensors model files.

    Optimized for 116 shards of 2.8GB each. Reads ``model_config.json`` and
    ``model.safetensors.index.json`` from the model directory and provides
    helpers to verify, inspect and load the individual weight shards.
    """

    def __init__(self, model_path: str):
        """
        Initialize the sharded model loader.

        Args:
            model_path: Path to the inference directory containing shards

        Raises:
            FileNotFoundError: If the config or index JSON file is missing.
            json.JSONDecodeError: If either JSON file is malformed.
        """
        self.model_path = Path(model_path)
        self.config_path = self.model_path / "model_config.json"
        self.index_path = self.model_path / "model.safetensors.index.json"

        logger.info("Loading configuration from %s", self.config_path)
        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        logger.info("Loading weight index from %s", self.index_path)
        with open(self.index_path, 'r') as f:
            self.index = json.load(f)

        # metadata: summary info (model_type, total_shards, total_size);
        # weight_map: tensor name -> shard file name.
        self.metadata = self.index.get("metadata", {})
        self.weight_map = self.index.get("weight_map", {})

        # Use .get() consistently: the previous direct indexing of
        # config['architectures_info'][...] raised KeyError at construction
        # time for configs without that section, unlike every other lookup
        # in this class.
        arch_info = self.config.get("architectures_info", {})
        logger.info("Model: %s", self.metadata.get("model_type", "unknown"))
        logger.info("Total shards: %s", self.metadata.get("total_shards", 0))
        logger.info("Total size: %.2f GB", self.metadata.get("total_size", 0) / 1e9)
        logger.info("Total parameters: %s", arch_info.get("total_parameters", "unknown"))
        logger.info("Active parameters: %s", arch_info.get("active_parameters", "unknown"))

    def get_shard_path(self, shard_name: str) -> Path:
        """Get full path to a shard file."""
        return self.model_path / shard_name

    def get_available_memory(self) -> Dict[str, float]:
        """
        Get available system (and GPU, if present) memory.

        Returns:
            Mapping with RAM totals/usage in GB plus, per CUDA device,
            total and not-yet-allocated memory in GB.
        """
        memory = psutil.virtual_memory()
        result = {
            "ram_total_gb": memory.total / 1e9,
            "ram_available_gb": memory.available / 1e9,
            "ram_percent_used": memory.percent
        }

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                gpu_mem = torch.cuda.get_device_properties(i).total_memory
                gpu_allocated = torch.cuda.memory_allocated(i)
                result[f"gpu_{i}_total_gb"] = gpu_mem / 1e9
                # "available" here means not allocated by this process;
                # other processes' usage is not accounted for.
                result[f"gpu_{i}_available_gb"] = (gpu_mem - gpu_allocated) / 1e9

        return result

    def load_shard(self, shard_name: str, device: str = "cpu") -> Dict[str, torch.Tensor]:
        """
        Load a single shard file.

        Args:
            shard_name: Name of the shard file
            device: Device to load tensors to

        Returns:
            Dictionary of weight tensors

        Raises:
            FileNotFoundError: If the shard file does not exist.
        """
        shard_path = self.get_shard_path(shard_name)

        if not shard_path.exists():
            raise FileNotFoundError(f"Shard not found: {shard_path}")

        logger.debug("Loading shard: %s", shard_name)
        return load_file(str(shard_path), device=device)

    def load_sharded_weights(
        self,
        device: str = "cpu",
        low_memory: bool = False,
        show_progress: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Load all sharded weights.

        Args:
            device: Device to load weights to
            low_memory: Drop the per-shard dict eagerly and empty the CUDA
                cache after each shard
            show_progress: Show progress bar

        Returns:
            Dictionary of all model weights
        """
        logger.info("Loading sharded model weights...")

        mem_info = self.get_available_memory()
        logger.info("Available RAM: %.2f GB", mem_info['ram_available_gb'])
        if "gpu_0_available_gb" in mem_info:
            logger.info("Available GPU 0: %.2f GB", mem_info['gpu_0_available_gb'])

        # A shard holds many tensors, so the index maps many keys to the
        # same file; dedupe to get the list of files to read.
        shard_files = sorted(set(self.weight_map.values()))
        total_shards = len(shard_files)

        logger.info("Loading %d shard files...", total_shards)

        all_weights = {}

        pbar = tqdm(shard_files, disable=not show_progress, desc="Loading shards")

        for shard_name in pbar:
            pbar.set_description(f"Loading {shard_name}")

            shard_weights = self.load_shard(shard_name, device=device)
            all_weights.update(shard_weights)

            if low_memory:
                # NOTE: the tensors themselves stay alive through
                # all_weights; this only drops the per-shard dict and
                # returns cached CUDA blocks to the allocator.
                del shard_weights
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        logger.info("Loaded %d weight tensors", len(all_weights))
        return all_weights

    def get_layer_weights(self, layer_idx: int) -> List[str]:
        """
        Get all weight keys for a specific layer.

        Args:
            layer_idx: Layer index

        Returns:
            List of weight keys for that layer
        """
        # Trailing dot prevents layer 1 from matching layers 10-19, etc.
        prefix = f"model.layers.{layer_idx}."
        return [k for k in self.weight_map.keys() if k.startswith(prefix)]

    def get_shard_for_weight(self, weight_key: str) -> Optional[str]:
        """
        Get shard file name for a specific weight.

        Args:
            weight_key: Weight key/name

        Returns:
            Shard file name, or None if the key is unknown
        """
        return self.weight_map.get(weight_key)

    def verify_shards(self) -> Dict[str, bool]:
        """
        Verify all shard files exist on disk.

        Returns:
            Dictionary mapping shard names to existence status
        """
        logger.info("Verifying shard files...")

        shard_files = set(self.weight_map.values())
        verification = {}

        for shard_name in tqdm(sorted(shard_files), desc="Verifying"):
            shard_path = self.get_shard_path(shard_name)
            verification[shard_name] = shard_path.exists()

        missing = [s for s, exists in verification.items() if not exists]

        if missing:
            logger.warning("Missing %d shard files:", len(missing))
            # Cap the listing at 10 entries to keep logs readable.
            for shard in missing[:10]:
                logger.warning(" - %s", shard)
            if len(missing) > 10:
                logger.warning(" ... and %d more", len(missing) - 10)
        else:
            logger.info("✓ All shard files present")

        return verification

    def load_metadata(self) -> Dict:
        """Return a summary dict of model metadata (no weights loaded)."""
        return {
            "config": self.config,
            "index": self.index,
            "total_shards": self.metadata.get("total_shards", 0),
            "total_size_gb": self.metadata.get("total_size", 0) / 1e9,
            "architecture": self.config.get("architectures_info", {}),
            "num_layers": self.config.get("num_hidden_layers", 0),
            "hidden_size": self.config.get("hidden_size", 0),
            "vocab_size": self.config.get("vocab_size", 0)
        }
| |
|
| |
|
def load_full_model(
    model_path: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    low_memory: bool = False
):
    """
    Convenience function to load the full model.

    Args:
        model_path: Path to inference directory
        device: Device to load model to
        low_memory: Use low memory loading

    Returns:
        Tuple of (loaded model weights dict, metadata dict)

    Raises:
        FileNotFoundError: If any shard file listed in the index is missing.
    """
    loader = ShardedModelLoader(model_path)

    # Fail fast before any multi-GB reads if shards are missing.
    verification = loader.verify_shards()
    missing = sum(1 for exists in verification.values() if not exists)

    if missing > 0:
        # Report the real shard count from the index rather than a
        # hard-coded "116", which would go stale for other models.
        raise FileNotFoundError(
            f"Cannot load model: {missing} shard files are missing. "
            f"Please download all {len(verification)} shard files."
        )

    weights = loader.load_sharded_weights(
        device=device,
        low_memory=low_memory,
        show_progress=True
    )

    metadata = loader.load_metadata()

    return weights, metadata
| |
|
| |
|
def inspect_model(model_path: str) -> None:
    """
    Inspect model structure without loading weights.

    Prints a human-readable report (architecture, MoE config, storage,
    context length) built from the JSON metadata, then verifies that
    every shard file listed in the index exists on disk.

    Args:
        model_path: Path to inference directory
    """
    loader = ShardedModelLoader(model_path)

    # Report header banner.
    print("\n" + "="*80)
    print("HELION-OSC MODEL INSPECTION")
    print("="*80)

    metadata = loader.load_metadata()

    # Architecture summary — all fields come from config's
    # "architectures_info" section and fall back to 'N/A'.
    print(f"\nModel Type: {metadata['architecture'].get('model_description', 'N/A')}")
    print(f"Architecture: {metadata['architecture'].get('architecture_type', 'N/A')}")
    print(f"Total Parameters: {metadata['architecture'].get('total_parameters', 'N/A')}")
    print(f"Active Parameters: {metadata['architecture'].get('active_parameters', 'N/A')}")

    print(f"\nModel Configuration:")
    print(f" Layers: {metadata['num_layers']}")
    print(f" Hidden Size: {metadata['hidden_size']}")
    print(f" Vocabulary Size: {metadata['vocab_size']}")
    print(f" Attention Heads: {metadata['config'].get('num_attention_heads', 'N/A')}")
    print(f" KV Heads: {metadata['config'].get('num_key_value_heads', 'N/A')}")

    print(f"\nMoE Configuration:")
    arch = metadata['architecture']
    print(f" Number of Experts: {arch.get('num_experts', 'N/A')}")
    print(f" Experts per Token: {arch.get('experts_per_token', 'N/A')}")
    print(f" Shared Experts: {arch.get('num_shared_experts', 'N/A')}")

    print(f"\nStorage Information:")
    print(f" Total Shards: {metadata['total_shards']}")
    print(f" Total Size: {metadata['total_size_gb']:.2f} GB")
    # NOTE(review): shard size and precision below are hard-coded labels,
    # not read from the metadata — verify they match the actual model.
    print(f" Shard Size: ~2.8 GB each")
    print(f" Format: safetensors")
    print(f" Precision: bfloat16")

    print(f"\nContext Length:")
    print(f" Max Position Embeddings: {metadata['config'].get('max_position_embeddings', 'N/A')}")
    print(f" RoPE Theta: {metadata['config'].get('rope_theta', 'N/A')}")

    print("\n" + "="*80)

    # Existence check only (no reads of the multi-GB shard contents).
    print("\nVerifying shard files...")
    verification = loader.verify_shards()
    present = sum(1 for exists in verification.values() if exists)
    total = len(verification)

    print(f"\nShard Status: {present}/{total} files present")

    if present == total:
        print("✓ All shard files are available")
    else:
        print(f"✗ Missing {total - present} shard files")
| |
|
| |
|
def main():
    """Command-line entry point: inspect, verify, or load a sharded model."""
    import argparse

    parser = argparse.ArgumentParser(description="Helion-OSC Sharded Model Loader")
    parser.add_argument(
        "model_path",
        type=str,
        help="Path to inference directory"
    )
    parser.add_argument(
        "--action",
        choices=["inspect", "verify", "load"],
        default="inspect",
        help="Action to perform"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to load model to"
    )
    parser.add_argument(
        "--low-memory",
        action="store_true",
        help="Use low memory mode"
    )

    args = parser.parse_args()

    # argparse restricts --action to the three choices, so the final
    # else-branch can only be "inspect" (also the default).
    if args.action == "verify":
        loader = ShardedModelLoader(args.model_path)
        loader.verify_shards()
    elif args.action == "load":
        logger.info("Loading full model...")
        weights, metadata = load_full_model(
            args.model_path,
            device=args.device,
            low_memory=args.low_memory
        )
        logger.info(f"Successfully loaded {len(weights)} weight tensors")
        logger.info(f"Model ready on {args.device}")
    else:
        inspect_model(args.model_path)


if __name__ == "__main__":
    main()