| | import torch |
| |
|
def get_precision_fac(precision: str):
    """Return the number of bytes per live parameter for a training precision.

    "mixed" trains with 2-byte (half-precision) parameters, "single" with
    4-byte (fp32) parameters.

    Raises:
        ValueError: if `precision` is neither "mixed" nor "single".
    """
    bytes_per_param = {"mixed": 2, "single": 4}.get(precision)
    if bytes_per_param is None:
        raise ValueError("Precision must be either 'mixed' or 'single'")
    return bytes_per_param
| |
|
| |
|
def get_params_fac(model_dtype: str):
    """Return the number of bytes per stored parameter for a model dtype.

    "float16" stores parameters in 2 bytes, "float32" in 4 bytes.

    Raises:
        ValueError: if `model_dtype` is neither "float16" nor "float32".
    """
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        # Bug fix: the old message claimed torch.float16/torch.float32 were
        # expected, but this function compares against the *strings*
        # "float16"/"float32" (see the branches above).
        raise ValueError("Model dtype must be either 'float16' or 'float32'")
| |
|
| |
|
| |
|
| | |
| |
|
# Bytes-per-parameter factors for the fp32 optimizer/master state.
# Adam-style optimizers keep two fp32 tensors per parameter (variance and
# momentum), 4 bytes each.
VARIANCE_FACTOR = 4
MOMENTUM_FACTOR = 4
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR
# fp32 gradient copy: 4 bytes per parameter.
FP32_GRADS_FACTOR = 4
# fp32 parameter copy: 4 bytes per parameter.
FP32_PARAM_FACTOR = 4
# The fp32 "master" weights kept alongside low-precision training params.
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR
| |
|
| |
|
def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate host and device memory (bytes) for ZeRO stage-1 model states.

    Returns:
        tuple: `(cpu_mem, gpu_mem)`, both as ints.
    """
    world_size = num_nodes * num_gpus_per_node
    live_param_bytes = get_precision_fac(precision)
    stored_param_bytes = get_params_fac(model_dtype)

    if cpu_offload:
        # GPU holds only the live parameters; all optimizer state is on host.
        gpu_mem = live_param_bytes * total_params
        host_factor = max(stored_param_bytes * world_size,
                          MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * host_factor * additional_buffer_factor
        return int(cpu_mem), int(gpu_mem)

    # No offload: the params and fp32-grads terms are not divided by the world
    # size (replicated), while the optimizer-state term is sharded across GPUs.
    if precision == "mixed":
        sharded_factor = OPTIMIZER_FACTOR + FP32_PARAM_FACTOR
    else:
        sharded_factor = OPTIMIZER_FACTOR
    gpu_mem = (live_param_bytes * total_params
               + FP32_GRADS_FACTOR * total_params
               + int(sharded_factor * total_params / world_size))
    cpu_mem = total_params * stored_param_bytes * num_gpus_per_node * additional_buffer_factor
    return int(cpu_mem), int(gpu_mem)
| |
|
| |
|
def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate host and device memory (bytes) for ZeRO stage-2 model states.

    Returns:
        tuple: `(cpu_mem, gpu_mem)`, both as ints.
    """
    world_size = num_nodes * num_gpus_per_node
    live_param_bytes = get_precision_fac(precision)
    stored_param_bytes = get_params_fac(model_dtype)

    if cpu_offload:
        # Only the live parameters stay on GPU; optimizer state is on host.
        gpu_mem = live_param_bytes * total_params
        host_factor = max(stored_param_bytes * world_size,
                          MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * host_factor * additional_buffer_factor
    else:
        # Stage 2 shards both the fp32 gradients and the optimizer state
        # (plus the fp32 param copy in mixed precision) across all GPUs.
        sharded_factor = FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_factor += FP32_PARAM_FACTOR
        gpu_mem = live_param_bytes * total_params + int(sharded_factor * total_params / world_size)
        cpu_mem = stored_param_bytes * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
| |
|
| |
|
def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype = "float16",
                                          ):
    """Estimate host and device memory (bytes) for ZeRO stage-3 model states.

    Args:
        total_params: total number of model parameters.
        largest_layer_params: parameter count of the single largest layer
            (stage 3 materializes one layer at a time, so this bounds the
            transient per-GPU footprint).
        num_gpus_per_node / num_nodes: cluster shape.
        cpu_offload: offload optimizer state to host memory.
        cpu_offload_params: additionally offload parameters to host
            (only meaningful when `cpu_offload` is True).
        zero_init: whether the model is constructed with zero.Init
            (params sharded at construction), which changes the host-side peak.
        additional_buffer_factor: safety multiplier applied to host estimates.
        precision: "mixed" or "single" (see get_precision_fac).
        model_dtype: "float16" or "float32" (see get_params_fac).

    Returns:
        tuple: `(cpu_mem, gpu_mem, largest_layer_memory)` — host bytes,
        device bytes, and the transient largest-layer device bytes.
    """
    total_gpus = num_nodes * num_gpus_per_node
    # Host-side state below is scaled per node, hence 1/num_nodes.
    gpus_factor = 1 / num_nodes

    precision_fac = get_precision_fac(precision)
    params_fac = get_params_fac(model_dtype)
    # Gradients use the same bytes-per-element as the live parameters.
    grads_fac = precision_fac

    # Transient GPU peak: params + grads of the single largest layer.
    largest_layer_memory = (grads_fac + precision_fac) * largest_layer_params

    if cpu_offload:
        if cpu_offload_params:
            # Params and optimizer state both on host: GPU only needs room
            # for the largest gathered layer.
            gpu_mem = largest_layer_memory
            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor * additional_buffer_factor
            else:
                # Without zero.Init the full checkpoint-dtype copy per local GPU
                # can dominate the per-node peak.
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + params_fac) * gpus_factor) * additional_buffer_factor

        else:
            # Params stay on GPU (sharded across all GPUs); the gathered
            # largest layer may exceed the per-GPU shard, so take the max.
            gpu_mem = max(
                largest_layer_memory,
                int((precision_fac) * total_params / total_gpus)
            )

            if zero_init:
                cpu_mem = total_params * (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor * additional_buffer_factor
            else:
                cpu_mem = total_params * max(params_fac * num_gpus_per_node, (MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * gpus_factor) * additional_buffer_factor
    else:
        # No offload: everything (params, grads, optimizer state) is sharded
        # on GPU; the unsharded largest layer still bounds the peak.
        if precision == "mixed":
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + FP32_PARAM_FACTOR) * total_params / total_gpus)
            )
        else:
            gpu_mem = max(
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * largest_layer_params),
                int((precision_fac + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR) * total_params / total_gpus)
            )

        if zero_init:
            # With zero.Init only one layer is ever materialized on host.
            cpu_mem = largest_layer_params * params_fac * num_gpus_per_node * additional_buffer_factor
        else:
            cpu_mem = total_params * params_fac * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
| |
|
| |
|
| |
|