| import torch |
|
|
def get_precision_fac(precision: str):
    """Return the bytes kept per parameter for the given training precision.

    "mixed" keeps a 2-byte (fp16) working copy; "single" keeps 4-byte fp32.

    Raises:
        ValueError: if *precision* is neither "mixed" nor "single".
    """
    bytes_per_param = {"mixed": 2, "single": 4}
    try:
        return bytes_per_param[precision]
    except KeyError:
        raise ValueError("Precision must be either 'mixed' or 'single'")
|
|
|
|
def get_params_fac(model_dtype: str):
    """Return the bytes per parameter for the given model dtype string.

    Args:
        model_dtype: either "float16" (2 bytes) or "float32" (4 bytes).

    Raises:
        ValueError: if *model_dtype* is not one of the accepted strings.
    """
    if model_dtype == "float16":
        return 2
    elif model_dtype == "float32":
        return 4
    else:
        # Bug fix: the old message said "torch.float16 or torch.float32",
        # but this function compares against plain strings, not torch dtypes.
        raise ValueError("Model dtype must be either 'float16' or 'float32'")
|
|
|
|
|
|
| |
|
|
# Per-parameter byte costs of the optimizer/model state kept in fp32
# (4 bytes each), as used by the ZeRO memory estimators below.
VARIANCE_FACTOR = 4  # Adam second-moment (variance) state
MOMENTUM_FACTOR = 4  # Adam first-moment (momentum) state
OPTIMIZER_FACTOR = VARIANCE_FACTOR + MOMENTUM_FACTOR  # 8 bytes/param total
FP32_GRADS_FACTOR = 4  # fp32 gradient copy
FP32_PARAM_FACTOR = 4  # fp32 parameter copy
MASTER_PARAMS_FACTOR = FP32_PARAM_FACTOR  # fp32 "master" weights (alias)
|
|
|
|
def estimate_zero1_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    """Estimate per-rank model-state memory for ZeRO stage 1.

    Returns:
        (cpu_mem, gpu_mem): estimated bytes needed on CPU and GPU, as ints.
    """
    world_size = num_nodes * num_gpus_per_node
    live_bytes = get_precision_fac(precision)     # working-copy bytes/param
    dtype_bytes = get_params_fac(model_dtype)     # model dtype bytes/param

    if cpu_offload:
        # Working parameters stay on GPU; optimizer state, fp32 grads and
        # master weights are held on the host.
        gpu_mem = live_bytes * total_params
        host_fac = max(dtype_bytes * world_size,
                       MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * host_fac * additional_buffer_factor
    else:
        # Optimizer state (plus fp32 params under mixed precision) is the
        # only sharded component in stage 1.
        sharded_fac = OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_fac += FP32_PARAM_FACTOR
        gpu_mem = (live_bytes * total_params
                   + FP32_GRADS_FACTOR * total_params
                   + int(sharded_fac * total_params / world_size))
        cpu_mem = total_params * dtype_bytes * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
|
def estimate_zero2_model_states_mem_needs(total_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    """Estimate per-rank model-state memory for ZeRO stage 2.

    Returns:
        (cpu_mem, gpu_mem): estimated bytes needed on CPU and GPU, as ints.
    """
    world_size = num_nodes * num_gpus_per_node
    live_bytes = get_precision_fac(precision)     # working-copy bytes/param
    dtype_bytes = get_params_fac(model_dtype)     # model dtype bytes/param

    if cpu_offload:
        # Only the working parameters remain device-resident.
        gpu_mem = live_bytes * total_params
        host_fac = max(dtype_bytes * world_size,
                       MASTER_PARAMS_FACTOR + OPTIMIZER_FACTOR + FP32_GRADS_FACTOR)
        cpu_mem = total_params * host_fac * additional_buffer_factor
    else:
        # Stage 2 shards grads and optimizer state (plus fp32 params when
        # training in mixed precision).
        sharded_fac = FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            sharded_fac += FP32_PARAM_FACTOR
        gpu_mem = live_bytes * total_params + int(sharded_fac * total_params / world_size)
        cpu_mem = dtype_bytes * total_params * num_gpus_per_node * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem)
|
|
|
|
def estimate_zero3_model_states_mem_needs(total_params,
                                          largest_layer_params,
                                          num_gpus_per_node=1,
                                          num_nodes=1,
                                          cpu_offload=True,
                                          cpu_offload_params=True,
                                          zero_init=True,
                                          additional_buffer_factor=1.5,
                                          precision="mixed",
                                          model_dtype="float16",
                                          ):
    """Estimate per-rank model-state memory for ZeRO stage 3.

    Returns:
        (cpu_mem, gpu_mem, largest_layer_memory): estimated CPU and GPU
        bytes (ints) plus the transient working memory for the largest layer.
    """
    total_gpus = num_nodes * num_gpus_per_node
    node_share = 1 / num_nodes

    live_bytes = get_precision_fac(precision)     # working-copy bytes/param
    dtype_bytes = get_params_fac(model_dtype)     # model dtype bytes/param
    grad_bytes = live_bytes                       # grads use the compute precision

    # Transient GPU working set: params + grads of the single largest layer.
    largest_layer_memory = (grad_bytes + live_bytes) * largest_layer_params

    if not cpu_offload:
        # Everything stays device-resident, sharded across all ranks.
        shard_fac = live_bytes + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if precision == "mixed":
            shard_fac += FP32_PARAM_FACTOR
        gpu_mem = max(int(shard_fac * largest_layer_params),
                      int(shard_fac * total_params / total_gpus))
        # Host-side pinned buffer scales with the largest layer under
        # zero.Init, otherwise with the full parameter count.
        pinned_params = largest_layer_params if zero_init else total_params
        cpu_mem = pinned_params * dtype_bytes * num_gpus_per_node * additional_buffer_factor
    elif cpu_offload_params:
        # Params and optimizer state both offloaded to host.
        gpu_mem = largest_layer_memory
        host_fac = MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR + dtype_bytes
        if zero_init:
            cpu_mem = total_params * host_fac * node_share * additional_buffer_factor
        else:
            cpu_mem = total_params * max(dtype_bytes * num_gpus_per_node,
                                         host_fac * node_share) * additional_buffer_factor
    else:
        # Optimizer state offloaded; sharded working params stay on GPU.
        gpu_mem = max(largest_layer_memory,
                      int(live_bytes * total_params / total_gpus))
        host_fac = MASTER_PARAMS_FACTOR + FP32_GRADS_FACTOR + OPTIMIZER_FACTOR
        if zero_init:
            cpu_mem = total_params * host_fac * node_share * additional_buffer_factor
        else:
            cpu_mem = total_params * max(dtype_bytes * num_gpus_per_node,
                                         host_fac * node_share) * additional_buffer_factor

    return int(cpu_mem), int(gpu_mem), largest_layer_memory
|
|
|
|
|
|