| | |
| |
|
| | import argparse |
| | from safetensors import safe_open |
| |
|
def list_safetensor_layers(filepath: str) -> None:
    """
    Print the name and shape of every tensor stored in a .safetensors file.

    Shapes are read through ``get_slice(...).get_shape()`` instead of
    ``get_tensor(...)``, so the tensor data itself is not materialised just
    to report its shape. This matters for large checkpoints.

    Errors are reported on stdout rather than raised: a missing file prints a
    "not found" message, and any other failure prints the exception text
    together with a hint that the file may not be a valid .safetensors file.

    Args:
        filepath (str): The path to the .safetensors file.
    """
    try:
        print(f"\n📄 Tensors in: {filepath}\n" + "=" * 50)

        total_tensors = 0
        with safe_open(filepath, framework="pt", device="cpu") as f:
            # framework="pt" already requires torch; it is imported here only
            # to render shapes as torch.Size, keeping the output text the
            # same as when full tensors were loaded.
            import torch

            for key in f.keys():
                shape = torch.Size(f.get_slice(key).get_shape())
                print(f"- {key:<50} | Shape: {shape}")
                total_tensors += 1

        print("=" * 50 + f"\n✅ Found {total_tensors} total tensors.\n")

    except FileNotFoundError:
        print(f"❌ Error: The file '{filepath}' was not found.")
    except Exception as e:
        # Top-level boundary of a CLI helper: report the problem and return
        # instead of surfacing a traceback to the user.
        print(f"❌ An error occurred: {e}")
        print("Please ensure the file is a valid .safetensors file.")
| |
|
if __name__ == "__main__":
    # Command-line entry point: take one positional path argument and print
    # the tensor listing for that file.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="List all layers (tensors) and their shapes in a .safetensors file.",
    )
    cli.add_argument("filepath", type=str, help="Path to the .safetensors file.")

    list_safetensor_layers(cli.parse_args().filepath)
| |
|