import argparse
import os
from collections import Counter

import torch


# Storage size in bytes per element for each PyTorch dtype. Alias names
# (half, float, double, short, int, long, cfloat, cdouble) are the same
# dtype objects as their canonical forms, so they map to identical values.
TORCH_DTYPE_TO_BYTES = {
    torch.bool: 1,

    # Floating point
    torch.float16: 2,
    torch.half: 2,
    torch.bfloat16: 2,
    torch.float32: 4,
    torch.float: 4,
    torch.float64: 8,
    torch.double: 8,

    # Complex
    torch.complex64: 8,
    torch.complex128: 16,
    torch.cfloat: 8,
    torch.cdouble: 16,

    # Signed integers
    torch.int8: 1,
    torch.int16: 2,
    torch.short: 2,
    torch.int32: 4,
    torch.int: 4,
    torch.int64: 8,
    torch.long: 8,

    # Unsigned integers (wider unsigned dtypes are registered conditionally
    # below, since older PyTorch releases do not define them)
    torch.uint8: 1,

    # Quantized dtypes (quint4x2 packs two 4-bit values per byte, so one
    # byte per element is an upper bound on its storage)
    torch.qint8: 1,
    torch.quint8: 1,
    torch.qint32: 4,
    torch.quint4x2: 1,
}
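
# uint16/uint32/uint64 are not defined in older PyTorch releases, so they
# are registered behind hasattr checks to keep the module importable there.
for _uint_name, _uint_size in (("uint16", 2), ("uint32", 4), ("uint64", 8)):
    if hasattr(torch, _uint_name):
        TORCH_DTYPE_TO_BYTES[getattr(torch, _uint_name)] = _uint_size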


def get_bytes_per_element(dtype):
    """Returns the number of bytes per element for a given PyTorch dtype."""
    # Fall back to dtype.itemsize, available on newer PyTorch releases, for
    # any dtype missing from the table; otherwise return None.
    return TORCH_DTYPE_TO_BYTES.get(dtype, getattr(dtype, "itemsize", None))
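
# e.g. get_bytes_per_element(torch.bfloat16) -> 2
#      get_bytes_per_element(torch.complex128) -> 16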


def get_dtype_name(dtype):
    """Returns a readable string for a PyTorch dtype, e.g. 'float16'."""
    return str(dtype).replace('torch.', '')


def calculate_num_elements(shape):
    """Calculates the total number of elements from a tensor shape tuple.

    Standalone equivalent of torch.Tensor.numel() for a plain shape tuple.
    """
    if not shape:
        return 1  # a 0-dim (scalar) tensor holds exactly one element
    if 0 in shape:
        return 0  # any zero-sized dimension makes the tensor empty
    num_elements = 1
    for dim_size in shape:
        num_elements *= dim_size
    return num_elements
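
# e.g. calculate_num_elements((3, 224, 224)) -> 150528
#      calculate_num_elements(()) -> 1, calculate_num_elements((0, 5)) -> 0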


def extract_tensors_from_obj(obj, prefix=""):
    """
    Recursively extracts tensors from nested dictionaries/objects.
    Returns a dictionary of {key: tensor} pairs with dot-separated keys.
    """
    tensors = {}

    if isinstance(obj, torch.Tensor):
        return {prefix or "tensor": obj}

    elif isinstance(obj, dict):
        for key, value in obj.items():
            new_prefix = f"{prefix}.{key}" if prefix else key
            tensors.update(extract_tensors_from_obj(value, new_prefix))

    elif hasattr(obj, 'state_dict') and callable(getattr(obj, 'state_dict')):
        # Objects that expose state_dict() (e.g. an nn.Module or optimizer):
        # recurse into their state dict.
        state_dict = obj.state_dict()
        new_prefix = f"{prefix}.state_dict" if prefix else "state_dict"
        tensors.update(extract_tensors_from_obj(state_dict, new_prefix))

    elif hasattr(obj, '__dict__'):
        # Fall back to scanning instance attributes for tensors (one level deep).
        for key, value in obj.__dict__.items():
            if isinstance(value, torch.Tensor):
                new_prefix = f"{prefix}.{key}" if prefix else key
                tensors[new_prefix] = value

    return tensors
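
# Example of the flattened key scheme this produces:
#   extract_tensors_from_obj({"model": {"fc.weight": torch.zeros(2, 2)}})
#   -> {"model.fc.weight": <2x2 tensor>}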


def inspect_pth_precision_and_size(filepath):
    """
    Reads a .pth file, extracts tensors from it, and reports each tensor's
    precision (dtype), actual size, and theoretical FP32-equivalent size.
    """
    if not os.path.exists(filepath):
        print(f"Error: File not found at '{filepath}'")
        return

    try:
        print(f"Loading PyTorch file: {filepath}")

        # Prefer weights_only=True, which refuses to unpickle arbitrary
        # Python objects; older PyTorch versions do not accept the argument.
        try:
            obj = torch.load(filepath, map_location="cpu", weights_only=True)
            print("(Loaded with weights_only=True for security)\n")
        except TypeError:
            obj = torch.load(filepath, map_location="cpu")
            print("(Warning: loaded without weights_only=True - older PyTorch version)\n")

        tensors = extract_tensors_from_obj(obj)

        if not tensors:
            print("No tensors found in the file.")
            return

        dtype_counts = Counter()
        total_actual_mb = 0.0
        total_fp32_equiv_mb = 0.0

        max_key_len = max(len("Tensor Name"), max(len(k) for k in tensors))

        header = (
            f"{'Tensor Name':<{max_key_len}} | "
            f"{'Precision (dtype)':<17} | "
            f"{'Shape':<20} | "
            f"{'Actual Size (MB)':>16} | "
            f"{'FP32 Equiv. (MB)':>18}"
        )
        print(header)
        print("-" * len(header))

        for key, tensor in tensors.items():
            dtype = tensor.dtype
            dtype_name = get_dtype_name(dtype)
            shape_str = str(tuple(tensor.shape))

            num_elements = tensor.numel()
            bytes_per_el_actual = get_bytes_per_element(dtype)

            actual_size_mb_str = "N/A"
            fp32_equiv_size_mb_str = "N/A"

            if bytes_per_el_actual is not None:
                actual_bytes = num_elements * bytes_per_el_actual
                actual_size_mb = actual_bytes / (1024 * 1024)
                total_actual_mb += actual_size_mb
                actual_size_mb_str = f"{actual_size_mb:.3f}"

                # Size the same tensor would occupy at 4 bytes per element.
                fp32_equiv_bytes = num_elements * 4
                fp32_equiv_mb = fp32_equiv_bytes / (1024 * 1024)
                total_fp32_equiv_mb += fp32_equiv_mb
                fp32_equiv_size_mb_str = f"{fp32_equiv_mb:.3f}"
            else:
                print(f"Warning: Unknown dtype '{dtype}' for tensor '{key}'. Cannot calculate size.")

            # Truncate long shape strings so the 20-character column stays aligned.
            if len(shape_str) > 20:
                shape_str = shape_str[:17] + "..."

            print(
                f"{key:<{max_key_len}} | "
                f"{dtype_name:<17} | "
                f"{shape_str:<20} | "
                f"{actual_size_mb_str:>16} | "
                f"{fp32_equiv_size_mb_str:>18}"
            )
            dtype_counts[dtype_name] += 1
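
        # Worked example of one row: a (1024, 1024) float16 tensor has
        # 1,048,576 elements * 2 bytes = 2.000 MB actual vs 4.000 MB at FP32.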

        print("\n--- Summary ---")
        print(f"Total tensors found: {len(tensors)}")
        if dtype_counts:
            print("Precision distribution:")
            for dtype_name, count in dtype_counts.most_common():
                print(f"  - {dtype_name:<12}: {count} tensor(s)")
        else:
            print("No dtypes to summarize.")

        print(f"\nTotal actual size of all tensors: {total_actual_mb:.3f} MB")
        print(f"Total theoretical FP32 size of all tensors: {total_fp32_equiv_mb:.3f} MB")

        if total_fp32_equiv_mb > 0:
            savings_percentage = (1 - total_actual_mb / total_fp32_equiv_mb) * 100
            print(f"Overall size reduction compared to full FP32: {savings_percentage:.2f}%")
        else:
            print("Overall size reduction cannot be calculated (no sizable tensors found).")

        # Report top-level entries that contributed no tensors (e.g. training
        # metadata) when the loaded object is a dict.
        non_tensor_keys = []
        if isinstance(obj, dict):
            tensor_top_level_keys = {k.split('.')[0] for k in tensors}
            for key, value in obj.items():
                if key not in tensor_top_level_keys:
                    non_tensor_keys.append(f"{key}: {type(value).__name__}")

        if non_tensor_keys:
            print("\nNon-tensor content found:")
            for item in non_tensor_keys[:5]:
                print(f"  - {item}")
            if len(non_tensor_keys) > 5:
                print(f"  ... and {len(non_tensor_keys) - 5} more items")

    except Exception as e:
        print(f"An error occurred while processing '{filepath}':")
        print(f"  {e}")
        print("Please ensure it is a valid PyTorch .pth file and that PyTorch is installed correctly.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inspect tensor precision (dtype) and size in a PyTorch .pth file."
    )
    parser.add_argument(
        "filepath",
        help="Path to the .pth file to inspect."
    )
    args = parser.parse_args()

    inspect_pth_precision_and_size(args.filepath)
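
# Example invocation (the script filename is illustrative):
#   python inspect_pth_precision.py path/to/checkpoint.pth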