Buckets:
MisterAI/LocalAI_Demo_backends / cpu-diffusers.upgrade-tmp /venv /lib /python3.10 /site-packages /accelerate /utils /dataclasses.py
| # Copyright 2022 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| General namespace and dataclass related classes | |
| """ | |
| import argparse | |
| import copy | |
| import enum | |
| import functools | |
| import logging | |
| import os | |
| import warnings | |
| from collections.abc import Iterable | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass, field | |
| from datetime import timedelta | |
| from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args | |
| import torch | |
| from .constants import ( | |
| BETA_CP_AVAILABLE_PYTORCH_VERSION, | |
| BETA_TP_AVAILABLE_PYTORCH_VERSION, | |
| BETA_TP_AVAILABLE_TRANSFORMERS_VERSION, | |
| FSDP2_PYTORCH_VERSION, | |
| FSDP_AUTO_WRAP_POLICY, | |
| FSDP_BACKWARD_PREFETCH, | |
| FSDP_SHARDING_STRATEGY, | |
| MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, | |
| XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, | |
| ) | |
| from .environment import parse_flag_from_env, str_to_bool | |
| from .imports import ( | |
| is_cuda_available, | |
| is_hpu_available, | |
| is_mlu_available, | |
| is_msamp_available, | |
| is_musa_available, | |
| is_npu_available, | |
| is_torchao_available, | |
| is_transformer_engine_available, | |
| is_xpu_available, | |
| ) | |
| from .versions import compare_versions, is_torch_version | |
| if TYPE_CHECKING: | |
| # Mock imports for type checking | |
| from torchao.float8 import Float8LinearConfig | |
# Module-level logger for this file.
logger = logging.getLogger(__name__)
class KwargsHandler:
    """
    Internal mixin that implements a `to_kwargs()` method for a dataclass.

    Subclasses must be constructible with no arguments: `to_kwargs()` builds a reference instance with
    `self.__class__()` to find out which attributes differ from their defaults.
    """

    def to_dict(self):
        """Return a deep copy of the instance attributes (`self.__dict__`)."""
        return copy.deepcopy(self.__dict__)

    def to_kwargs(self):
        """
        Returns a dictionary containing the attributes with values different from the default of this class.

        The reference (default) instance is created inside `clear_environment()`, presumably so that values a
        subclass reads from environment variables in `__post_init__` are not counted as defaults.
        """
        # import clear_environment here to avoid circular import problem
        from .environment import clear_environment

        with clear_environment():
            default_dict = self.__class__().to_dict()
        this_dict = self.to_dict()
        return {k: v for k, v in this_dict.items() if default_dict[k] != v}
class EnumWithContains(enum.EnumMeta):
    "A metaclass that adds the ability to check if `self` contains an item with the `in` operator"

    def __contains__(cls, item):
        # `item in cls` is True exactly when `cls(item)` succeeds, i.e. `item` is a member or a member's value.
        try:
            cls(item)
        except ValueError:
            return False
        return True
class BaseEnum(enum.Enum, metaclass=EnumWithContains):
    "An enum class that can get the value of an item with `str(Enum.key)`"

    def __str__(self):
        return self.value

    @classmethod
    def list(cls):
        "Method to list all the possible items in `cls`"
        # `list` here is the builtin: class-body names are not visible inside method bodies.
        return list(map(str, cls))
@dataclass
class AutocastKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the
    documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more
    information on each argument.

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import AutocastKwargs

    kwargs = AutocastKwargs(cache_enabled=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    """

    enabled: bool = True
    cache_enabled: Optional[bool] = None
class DDPCommunicationHookType(BaseEnum):
    """
    Represents a type of communication hook used in DDP.

    Values:

        - **NO** -- no communication hook
        - **FP16** -- DDP communication hook to compress the gradients in FP16
        - **BF16** -- DDP communication hook to compress the gradients in BF16
        - **POWER_SGD** -- DDP communication hook to use PowerSGD
        - **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD
    """

    NO = "no"
    FP16 = "fp16"
    BF16 = "bf16"
    POWER_SGD = "power_sgd"
    BATCHED_POWER_SGD = "batched_power_sgd"
@dataclass
class DistributedDataParallelKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize how your model is wrapped in a
    `torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this
    [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more
    information on each argument.

    <Tip warning={true}>

    `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.

    `static_graph` is only available in PyTorch 1.11.0 and later versions.

    </Tip>

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import DistributedDataParallelKwargs

    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    """

    dim: int = 0
    broadcast_buffers: bool = True
    bucket_cap_mb: int = 25
    find_unused_parameters: bool = False
    check_reduction: bool = False
    gradient_as_bucket_view: bool = False
    static_graph: bool = False

    # Communication-hook settings; consumed by `register_comm_hook`, excluded from `to_dict`.
    comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
    comm_wrapper: Literal[
        DDPCommunicationHookType.NO,
        DDPCommunicationHookType.FP16,
        DDPCommunicationHookType.BF16,
    ] = DDPCommunicationHookType.NO
    comm_state_option: dict = field(default_factory=dict)

    def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_state_option")):
        """
        Return the attributes as a dict, dropping `ignore_keys`.

        The communication-hook fields are dropped by default, presumably because they are not
        `DistributedDataParallel` constructor arguments (they are handled by `register_comm_hook`).
        """
        return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}

    def register_comm_hook(self, model):
        """
        Register the communication hook selected by `comm_hook` on `model`.

        `model` must expose `register_comm_hook(state=..., hook=...)` (presumably a `DistributedDataParallel`
        instance). Behavior:

        - Nothing is registered when `comm_hook` is `NO`.
        - When `comm_wrapper` is `FP16`/`BF16`, the chosen hook is wrapped with the matching compression wrapper.
        - PowerSGD hooks receive a `PowerSGDState(None, **comm_state_option)` as state; every other hook receives
          `state=None`.
        """
        from torch.distributed.algorithms.ddp_comm_hooks import (
            default_hooks,
            powerSGD_hook,
        )

        hook_map: dict[DDPCommunicationHookType, Callable] = {
            DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
            DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook,
            DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook,
            DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook,
        }

        wrapper_map: dict[DDPCommunicationHookType, Callable] = {
            DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper,
            DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper,
        }

        hook: Optional[Callable] = hook_map.get(self.comm_hook)
        wrapper: Optional[Callable] = wrapper_map.get(self.comm_wrapper)

        if hook and wrapper:
            hook = wrapper(hook)

        if hook:
            state = (
                powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
                if self.comm_hook
                in (
                    DDPCommunicationHookType.POWER_SGD,
                    DDPCommunicationHookType.BATCHED_POWER_SGD,
                )
                else None
            )
            model.register_comm_hook(
                state=state,
                hook=hook,
            )
@dataclass
class GradScalerKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
    `torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this
    [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.

    <Tip warning={true}>

    `torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is
    only available in PyTorch 2.4.0 and later versions.

    </Tip>

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import GradScalerKwargs

    kwargs = GradScalerKwargs(backoff_factor=0.25)
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    """

    init_scale: float = 65536.0
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    enabled: bool = True
@dataclass
class InitProcessGroupKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer
    to the documentation of this
    [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
    information on each argument.

    Note: If `timeout` is set to `None`, the default is based upon how `backend` is set: 600 seconds (10 minutes) for
    `"nccl"` and 1800 seconds (30 minutes) for any other backend.

    Example:

    ```python
    from datetime import timedelta
    from accelerate import Accelerator
    from accelerate.utils import InitProcessGroupKwargs

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800))
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```
    """

    backend: Optional[str] = "nccl"
    init_method: Optional[str] = None
    timeout: Optional[timedelta] = None

    def __post_init__(self):
        # Fill in a backend-dependent default timeout when none was given.
        if self.timeout is None:
            seconds = 1800 if self.backend != "nccl" else 600
            self.timeout = timedelta(seconds=seconds)
# Literals
# Allowed string values for the FP8 recipe kwargs below; validated at runtime with `typing.get_args`.
Backend = Literal["MSAMP", "TE"]
OptLevel = Literal["O1", "O2"]
FP8Format = Literal["HYBRID", "E4M3", "E5M2"]
AmaxComputeAlgorithm = Literal["max", "most_recent"]
# FP8 training recipe kwargs
@dataclass
class AORecipeKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
    training with `torchao` FP8.

    Args:
        config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):
            The configuration for the FP8 training. If `None`, a default config will be created with sensible
            defaults for most use cases:
            - `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm`
              operations to prevent runtime errors.
            - `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth
              savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.

            You can override these defaults by providing your own `Float8LinearConfig` instance.
        module_filter_func (`Callable`, *optional*, default to `None`):
            Optional function that must take in a module and layer name, and returns a boolean indicating whether the
            module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
            example.
        pad_inner_dim (`bool`, *optional*, default to `None`):
            Only used when `config` is `None`. If `None`, read from the `ACCELERATE_FP8_PAD_INNER_DIM` environment
            flag, defaulting to `True`.
        enable_fsdp_float8_all_gather (`bool`, *optional*, default to `None`):
            Only used when `config` is `None`. If `None`, read from the `ACCELERATE_FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER`
            environment flag, defaulting to `True`.

    Raises:
        ImportError: If `torchao` is not available.
    """

    config: Optional["Float8LinearConfig"] = None
    module_filter_func: Optional[Callable] = None
    pad_inner_dim: Optional[bool] = None
    enable_fsdp_float8_all_gather: Optional[bool] = None

    def __post_init__(self):
        env_prefix = "ACCELERATE_FP8_"
        if not is_torchao_available():
            raise ImportError("TorchAO is not available. Please install it or use a different backend.")

        # A user-supplied config wins; the two flag fields only shape the config built here.
        if self.config is None:
            from torchao.float8 import Float8LinearConfig

            # Check environment variables for overrides
            if self.pad_inner_dim is None:
                self.pad_inner_dim = parse_flag_from_env(env_prefix + "PAD_INNER_DIM", default=True)
            if self.enable_fsdp_float8_all_gather is None:
                self.enable_fsdp_float8_all_gather = parse_flag_from_env(
                    env_prefix + "ENABLE_FSDP_FLOAT8_ALL_GATHER", default=True
                )
            self.config = Float8LinearConfig(
                pad_inner_dim=self.pad_inner_dim,
                enable_fsdp_float8_all_gather=self.enable_fsdp_float8_all_gather,
            )
@dataclass
class TERecipeKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
    training with `transformer-engine`.

    <Tip>

    For more information on the args, please refer to the API
    [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html).

    </Tip>

    ```python
    from accelerate import Accelerator
    from accelerate.utils import TERecipeKwargs

    kwargs = TERecipeKwargs(fp8_format="HYBRID")
    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
    ```

    Every argument left as `None` is first read from an `ACCELERATE_FP8_*` environment variable (for example `margin`
    from `ACCELERATE_FP8_MARGIN`, `fp8_format` from `ACCELERATE_FP8_FORMAT`) before falling back to the default below.

    Args:
        use_autocast_during_eval (`bool`, *optional*, default to `False`):
            Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.
        margin (`int`, *optional*, default to 0):
            The margin to use for the gradient scaling.
        interval (`int`, *optional*, default to 1):
            The interval to use for how often the scaling factor is recomputed.
        fp8_format (`str`, *optional*, default to "HYBRID"):
            The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2` (case-insensitive).
            (Generally `HYBRID` for training, `E4M3` or `E5M2` for evaluation)
        amax_history_len (`int`, *optional*, default to 1024):
            The length of the history to use for the scaling factor computation
        amax_compute_algo (`str`, *optional*, default to "most_recent"):
            The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`
            (case-insensitive).
        override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):
            Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. When `None`, each entry
            is read from `ACCELERATE_FP8_OVERRIDE_FPROP`, `ACCELERATE_FP8_OVERRIDE_DGRAD` and
            `ACCELERATE_FP8_OVERRIDE_WGRAD` respectively.
        use_mxfp8_block_scaling (`bool`, *optional*):
            Whether to use MXFP8 block scaling. When `None`, read from the `ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING`
            environment flag.

    Raises:
        ImportError: If `transformer-engine` is not available.
        ValueError: If `fp8_format` or `amax_compute_algo` is not one of the allowed values.
    """

    use_autocast_during_eval: Optional[bool] = None
    margin: Optional[int] = None
    interval: Optional[int] = None
    fp8_format: FP8Format = None
    amax_history_len: Optional[int] = None
    amax_compute_algo: AmaxComputeAlgorithm = None
    override_linear_precision: tuple[bool, bool, bool] = None
    use_mxfp8_block_scaling: Optional[bool] = None

    def __post_init__(self):
        env_prefix = "ACCELERATE_FP8_"
        if not is_transformer_engine_available():
            raise ImportError("TransformerEngine is not available. Please install it or use a different backend.")
        if self.use_autocast_during_eval is None:
            self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
        if self.margin is None:
            self.margin = int(os.environ.get(env_prefix + "MARGIN", 0))
        if self.interval is None:
            self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1))
        if self.fp8_format is None:
            self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID")
        # Normalize case before validating so "hybrid" is accepted.
        self.fp8_format = self.fp8_format.upper()
        if self.fp8_format not in get_args(FP8Format):
            raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.")
        if self.amax_compute_algo is None:
            self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent")
        self.amax_compute_algo = self.amax_compute_algo.lower()
        if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm):
            raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}")
        if self.amax_history_len is None:
            self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024))
        if self.override_linear_precision is None:
            fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP")
            dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
            wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
            self.override_linear_precision = (fprop, dgrad, wgrad)
        if self.use_mxfp8_block_scaling is None:
            self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
@dataclass
class MSAMPRecipeKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
    training with `ms-amp`.

    Args:
        opt_level (`str`, *optional*, default to "O2"):
            The MS-AMP optimization level, one of `O1` or `O2` (case-insensitive). When `None`, it is read from the
            `ACCELERATE_FP8_OPT_LEVEL` environment variable.

    Raises:
        ValueError: If `opt_level` is not one of the allowed values.
    """

    opt_level: OptLevel = None

    def __post_init__(self):
        env_prefix = "ACCELERATE_FP8_"
        if self.opt_level is None:
            self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
        # Normalize case like `TERecipeKwargs` does for its string options, so "o2" is accepted.
        if isinstance(self.opt_level, str):
            self.opt_level = self.opt_level.upper()
        if self.opt_level not in get_args(OptLevel):
            raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}")
@dataclass
class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):
    """
    Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`
    instead.

    Args:
        backend (`str`, *optional*, default to `None`):
            Which FP8 backend the recipe targets, `MSAMP` or `TE` (case-insensitive). When `None`, it is read from the
            `ACCELERATE_FP8_BACKEND` environment variable, falling back to `MSAMP` if `is_msamp_available()` and `TE`
            otherwise.

    Raises:
        ValueError: If `backend` is not `MSAMP` or `TE`.
        ImportError: If `backend` is `TE` and transformer-engine is not available.
    """

    backend: Backend = None

    def __post_init__(self):
        env_prefix = "ACCELERATE_FP8_"
        warnings.warn(
            "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. "
            "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.",
            FutureWarning,
        )
        default_backend = "msamp" if is_msamp_available() else "te"
        if self.backend is None:
            self.backend = os.environ.get(env_prefix + "BACKEND", default_backend)
        self.backend = self.backend.upper()
        if self.backend not in get_args(Backend):
            raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.")
        # Neither parent `__post_init__` chains to `super()`, so `super().__post_init__()` would only reach
        # `TERecipeKwargs.__post_init__` (first in the MRO). That left `opt_level` unset and made the MSAMP backend
        # require TransformerEngine. Initialize each parent explicitly instead:
        # - the TE fields whenever TE is the backend (ImportError if missing) or TE is installed anyway;
        # - the MSAMP `opt_level` always.
        if self.backend == "TE" or is_transformer_engine_available():
            TERecipeKwargs.__post_init__(self)
        MSAMPRecipeKwargs.__post_init__(self)
# Literal
# Activity names accepted by `ProfileKwargs.activities`.
ProfilerActivity = Literal["cpu", "xpu", "mtia", "cuda", "hpu"]
@dataclass
class ProfileKwargs(KwargsHandler):
    """
    Use this object in your [`Accelerator`] to customize the initialization of the profiler. Please refer to the
    documentation of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for
    more information on each argument.

    <Tip warning={true}>

    `torch.profiler` is only available in PyTorch 1.8.1 and later versions.

    </Tip>

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import ProfileKwargs

    kwargs = ProfileKwargs(activities=["cpu", "cuda"])
    accelerator = Accelerator(kwargs_handlers=[kwargs])
    ```

    Args:
        activities (`List[str]`, *optional*, default to `None`):
            The list of activity groups to use in profiling. Each entry must be one of `"cpu"`, `"xpu"`, `"mtia"`,
            `"hpu"` or `"cuda"`; `"hpu"`, `"xpu"` and `"mtia"` are only accepted when that device (and, for XPU/MTIA,
            a recent enough PyTorch) is available.
        schedule_option (`Dict[str, int]`, *optional*, default to `None`):
            The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and
            `skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the
            warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then
            repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat`
            parameter, the zero value means that the cycles will continue until the profiling is finished.
        on_trace_ready (`Callable`, *optional*, default to `None`):
            Callable that is called at each step when schedule returns `ProfilerAction.RECORD_AND_SAVE` during the
            profiling.
        record_shapes (`bool`, *optional*, default to `False`):
            Save information about operator's input shapes.
        profile_memory (`bool`, *optional*, default to `False`):
            Track tensor memory allocation/deallocation
        with_stack (`bool`, *optional*, default to `False`):
            Record source information (file and line number) for the ops.
        with_flops (`bool`, *optional*, default to `False`):
            Use formula to estimate the FLOPS of specific operators
        with_modules (`bool`, *optional*, default to `False`):
            Record module hierarchy (including function names) corresponding to the callstack of the op.
        output_trace_dir (`str`, *optional*, default to `None`):
            Exports the collected trace in Chrome JSON format. Chrome use 'chrome://tracing' view json file. Defaults
            to None, which means profiling does not store json files. Not consumed by `build()` itself (presumably
            handled by the caller that runs the profiler).
    """

    activities: Optional[list[ProfilerActivity]] = None
    schedule_option: Optional[dict[str, int]] = None
    on_trace_ready: Optional[Callable] = None
    record_shapes: bool = False
    profile_memory: bool = False
    with_stack: bool = False
    with_flops: bool = False
    with_modules: bool = False
    output_trace_dir: Optional[str] = None

    def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity:
        """Get the profiler activity from the string.

        Args:
            activity (str): The profiler activity name.

        Returns:
            torch.profiler.ProfilerActivity: The profiler activity.

        Raises:
            ValueError: If `activity` is unknown or its device is not available in this environment.
        """

        profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = {
            "cpu": torch.profiler.ProfilerActivity.CPU,
            "cuda": torch.profiler.ProfilerActivity.CUDA,
        }

        # Device-specific activities are only offered when the device (and PyTorch support) is present.
        if is_hpu_available():
            profiler_activity_map["hpu"] = torch.profiler.ProfilerActivity.HPU

        if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION):
            if torch.xpu.is_available():
                profiler_activity_map["xpu"] = torch.profiler.ProfilerActivity.XPU

        if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION):
            if torch.mtia.is_available():
                profiler_activity_map["mtia"] = torch.profiler.ProfilerActivity.MTIA

        if activity not in profiler_activity_map:
            raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(profiler_activity_map)}.")
        return profiler_activity_map[activity]

    def build(self) -> torch.profiler.profile:
        """
        Build a profiler object with the current configuration.

        Returns:
            torch.profiler.profile: The profiler object.
        """
        # Translated torch enum values (not the `ProfilerActivity` string literals); `None` lets torch pick defaults.
        activities: Optional[list[torch.profiler.ProfilerActivity]] = None
        if self.activities is not None:
            activities = [self._get_profiler_activity(activity) for activity in self.activities]
        schedule: Optional[Callable] = None
        if self.schedule_option is not None:
            schedule = torch.profiler.schedule(**self.schedule_option)

        return torch.profiler.profile(
            activities=activities,
            schedule=schedule,
            on_trace_ready=self.on_trace_ready,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops,
            with_modules=self.with_modules,
        )
class DistributedType(str, enum.Enum):
    """
    Represents a type of distributed environment.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **MULTI_CPU** -- Distributed on multiple CPU nodes.
        - **MULTI_GPU** -- Distributed on multiple GPUs.
        - **MULTI_MLU** -- Distributed on multiple MLUs.
        - **MULTI_SDAA** -- Distributed on multiple SDAAs.
        - **MULTI_MUSA** -- Distributed on multiple MUSAs.
        - **MULTI_NPU** -- Distributed on multiple NPUs.
        - **MULTI_XPU** -- Distributed on multiple XPUs.
        - **MULTI_HPU** -- Distributed on multiple HPUs.
        - **MULTI_NEURON** -- Distributed on multiple Neuron cores.
        - **DEEPSPEED** -- Using DeepSpeed.
        - **FSDP** -- Using FSDP.
        - **XLA** -- Using TorchXLA.
        - **MEGATRON_LM** -- Using Megatron-LM.
    """

    # Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
    NO = "NO"
    MULTI_CPU = "MULTI_CPU"
    MULTI_GPU = "MULTI_GPU"
    MULTI_NPU = "MULTI_NPU"
    MULTI_MLU = "MULTI_MLU"
    MULTI_SDAA = "MULTI_SDAA"
    MULTI_MUSA = "MULTI_MUSA"
    MULTI_XPU = "MULTI_XPU"
    DEEPSPEED = "DEEPSPEED"
    FSDP = "FSDP"
    XLA = "XLA"
    MEGATRON_LM = "MEGATRON_LM"
    MULTI_HPU = "MULTI_HPU"
    MULTI_NEURON = "MULTI_NEURON"
class SageMakerDistributedType(str, enum.Enum):
    """
    Represents a type of distributed environment.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
        - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
    """

    # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.
    NO = "NO"
    DATA_PARALLEL = "DATA_PARALLEL"
    MODEL_PARALLEL = "MODEL_PARALLEL"
class FP8BackendType(str, enum.Enum):
    """
    Represents the backend used for FP8.

    Values:

        - **NO** -- no FP8 backend.
        - **TE** -- using TransformerEngine.
        - **MSAMP** -- using msamp.
        - **AO** -- using torchao.
    """

    # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"
class ComputeEnvironment(str, enum.Enum):
    """
    Represents a type of the compute environment.

    Values:

        - **LOCAL_MACHINE** -- private/custom cluster hardware.
        - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
    """

    # Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
    LOCAL_MACHINE = "LOCAL_MACHINE"
    AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"
class DynamoBackend(str, BaseEnum):
    """
    Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html).

    Values:

        - **NO** -- Do not use torch dynamo.
        - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo
          issues.
        - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's
          extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
        - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton
          kernels. [Read
          more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
        - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
        - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read
          more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
        - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read
          more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
        - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/)
        - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read
          more](https://github.com/onnx/onnx-tensorrt)
        - **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read
          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
        - **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read
          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
        - **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
        - **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
    """

    # Subclassing str as well as Enum allows the `DynamoBackend` to be JSON-serializable out of the box.
    NO = "NO"
    EAGER = "EAGER"
    AOT_EAGER = "AOT_EAGER"
    INDUCTOR = "INDUCTOR"
    AOT_TS_NVFUSER = "AOT_TS_NVFUSER"
    NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER"
    CUDAGRAPHS = "CUDAGRAPHS"
    OFI = "OFI"
    FX2TRT = "FX2TRT"
    ONNXRT = "ONNXRT"
    TENSORRT = "TENSORRT"
    AOT_TORCHXLA_TRACE_ONCE = "AOT_TORCHXLA_TRACE_ONCE"
    TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE"
    TVM = "TVM"
    HPU_BACKEND = "HPU_BACKEND"
class LoggerType(BaseEnum):
    """Represents a type of supported experiment tracker

    Values:

        - **ALL** -- all available trackers in the environment that are supported
        - **AIM** -- aim as an experiment tracker
        - **TENSORBOARD** -- TensorBoard as an experiment tracker
        - **WANDB** -- wandb as an experiment tracker
        - **TRACKIO** -- trackio as an experiment tracker
        - **COMETML** -- comet_ml as an experiment tracker
        - **MLFLOW** -- mlflow as an experiment tracker
        - **CLEARML** -- clearml as an experiment tracker
        - **DVCLIVE** -- dvclive as an experiment tracker
        - **SWANLAB** -- swanlab as an experiment tracker
    """

    ALL = "all"
    AIM = "aim"
    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
    TRACKIO = "trackio"
    COMETML = "comet_ml"
    MLFLOW = "mlflow"
    CLEARML = "clearml"
    DVCLIVE = "dvclive"
    SWANLAB = "swanlab"
class PrecisionType(str, BaseEnum):
    """Represents a type of precision used on floating point values

    Values:

        - **NO** -- using full precision (FP32)
        - **FP8** -- using 8-bit floating point precision
        - **FP16** -- using half precision
        - **BF16** -- using brain floating point precision
    """

    NO = "no"
    FP8 = "fp8"
    FP16 = "fp16"
    BF16 = "bf16"
class RNGType(BaseEnum):
    """
    Names a kind of random number generator, looked up by its lowercase value (e.g. `RNGType("cuda")`).

    Most values match a device/backend name used elsewhere in this file; `GENERATOR` presumably refers to a
    standalone `torch.Generator` -- TODO confirm against callers.
    """

    TORCH = "torch"
    CUDA = "cuda"
    MLU = "mlu"
    SDAA = "sdaa"
    MUSA = "musa"
    NPU = "npu"
    XLA = "xla"
    XPU = "xpu"
    HPU = "hpu"
    NEURON = "neuron"
    GENERATOR = "generator"
class CustomDtype(enum.Enum):
    r"""
    An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.
    """

    FP8 = "fp8"
    INT4 = "int4"
    INT2 = "int2"
# data classes


@dataclass
class TensorInformation:
    """Shape and dtype of a tensor."""

    shape: torch.Size
    dtype: torch.dtype
@dataclass
class DataLoaderConfiguration:
    """
    Configuration for dataloader-related items when calling `accelerator.prepare`.

    Args:
        split_batches (`bool`, defaults to `False`):
            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
            `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a
            round multiple of `num_processes` you are using. If `False`, actual batch size used will be the one set in
            your script multiplied by the number of processes.
        dispatch_batches (`bool`, defaults to `None`):
            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
            underlying dataset is an `IterableDataset`, `False` otherwise.
        even_batches (`bool`, defaults to `True`):
            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
            all workers.
        use_seedable_sampler (`bool`, defaults to `False`):
            Whether or not to use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
            training results are fully reproducible using a different sampling technique. While seed-to-seed results
            may differ, on average the differences are negligible when using multiple different seeds to compare.
            Should also be run with [`~utils.set_seed`] for the best results.
        data_seed (`int`, defaults to `None`):
            The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
            will use the current default seed from torch.
        non_blocking (`bool`, defaults to `False`):
            If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device
            transfers, allowing for better overlap between dataloader communication and computation. Recommended that
            the prepared dataloader has `pin_memory` set to `True` to work properly.
        use_stateful_dataloader (`bool`, defaults to `False`):
            If set to `True`, the dataloader prepared by the Accelerator will be backed by
            [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
            This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
    """

    split_batches: bool = field(
        default=False,
        metadata={
            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
            " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
            " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
            " in your script multiplied by the number of processes."
        },
    )
    # `None` means "decide automatically from the dataset type" (see docstring), hence Optional.
    dispatch_batches: Optional[bool] = field(
        default=None,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
            " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
            " underlying dataset is an `IterableDataset`, `False` otherwise."
        },
    )
    even_batches: bool = field(
        default=True,
        metadata={
            "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
            " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
            " all workers."
        },
    )
    use_seedable_sampler: bool = field(
        default=False,
        metadata={
            # Each implicitly-concatenated piece ends with a space so the rendered help text reads correctly.
            "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). "
            "Ensures training results are fully reproducible using a different sampling technique. "
            "While seed-to-seed results may differ, on average the differences are negligible when using "
            "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
        },
    )
    data_seed: Optional[int] = field(
        default=None,
        metadata={
            "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator"
            " will use the current default seed from torch."
        },
    )
    non_blocking: bool = field(
        default=False,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device"
            " transfers, allowing for better overlap between dataloader communication and computation. Recommended that the"
            " prepared dataloader has `pin_memory` set to `True` to work properly."
        },
    )
    use_stateful_dataloader: bool = field(
        default=False,
        metadata={
            "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
            "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
        },
    )
@dataclass
class ProjectConfiguration:
    """
    Configuration for the Accelerator object based on inner-project needs.

    Args:
        project_dir (`str`, defaults to `None`):
            A path to a directory for storing data.
        logging_dir (`str`, defaults to `None`):
            A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`.
        automatic_checkpoint_naming (`bool`, defaults to `False`):
            Whether saved states should be automatically iteratively named.
        total_limit (`int`, defaults to `None`):
            The maximum number of total saved states to keep.
        iteration (`int`, defaults to `0`):
            The current save iteration.
        save_on_each_node (`bool`, defaults to `False`):
            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
            the main one.
    """

    project_dir: Optional[str] = field(default=None, metadata={"help": "A path to a directory for storing data."})
    logging_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`."
        },
    )
    automatic_checkpoint_naming: bool = field(
        default=False,
        metadata={"help": "Whether saved states should be automatically iteratively named."},
    )
    total_limit: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum number of total saved states to keep."},
    )
    iteration: int = field(
        default=0,
        metadata={"help": "The current save iteration."},
    )
    save_on_each_node: bool = field(
        default=False,
        metadata={
            "help": (
                "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
                " only on the main one"
            )
        },
    )

    def set_directories(self, project_dir: Optional[str] = None):
        """
        Set `self.project_dir` to `project_dir` and, if `self.logging_dir` is still `None`, point it there too.

        An already-set `logging_dir` is left untouched.
        """
        self.project_dir = project_dir
        if self.logging_dir is None:
            self.logging_dir = project_dir

    def __post_init__(self):
        # Apply the `logging_dir` -> `project_dir` fallback at construction time.
        self.set_directories(self.project_dir)
@dataclass
class GradientAccumulationPlugin(KwargsHandler):
    """
    A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or
    `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error.

    Parameters:
        num_steps (`int`):
            The number of steps to accumulate gradients for.
        adjust_scheduler (`bool`, *optional*, defaults to `True`):
            Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be
            `True` if the used scheduler was not adjusted for gradient accumulation.
        sync_with_dataloader (`bool`, *optional*, defaults to `True`):
            Whether to synchronize setting the gradients when at the end of the dataloader.
        sync_each_batch (`bool`, *optional*, defaults to `False`):
            Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
            requirements when using gradient accumulation with distributed training, at expense of speed.

    Example:

    ```python
    from accelerate import Accelerator
    from accelerate.utils import GradientAccumulationPlugin

    gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2)
    accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin)
    ```
    """

    num_steps: Optional[int] = field(
        default=None,
        metadata={"help": "The number of steps to accumulate gradients for."},
    )
    adjust_scheduler: bool = field(
        default=True,
        metadata={
            "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation."
        },
    )
    sync_with_dataloader: bool = field(
        default=True,
        metadata={
            "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing."
        },
    )
    sync_each_batch: bool = field(
        default=False,
        metadata={
            "help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed."
        },
    )
@dataclass
class TorchDynamoPlugin(KwargsHandler):
    """
    This plugin is used to compile a model with PyTorch 2.0

    Args:
        backend (`DynamoBackend`, defaults to `None`):
            A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details. If
            `None`, it is read from the `ACCELERATE_DYNAMO_BACKEND` environment variable (falling back to `"no"`).
        mode (`str`, defaults to `None`):
            Possible options are 'default', 'reduce-overhead' or 'max-autotune'. If `None`, it is read from
            `ACCELERATE_DYNAMO_MODE` (falling back to `"default"`).
        fullgraph (`bool`, defaults to `None`):
            Whether to require the whole model to be captured in a single graph, i.e. graph breaks are not allowed.
            If `None`, it is read from `ACCELERATE_DYNAMO_USE_FULLGRAPH` (falling back to `False`).
        dynamic (`bool`, defaults to `None`):
            Whether to use dynamic shape for tracing. If `None`, it is read from `ACCELERATE_DYNAMO_USE_DYNAMIC` only
            when that variable is set; otherwise it stays `None`.
        options (`Any`, defaults to `None`):
            A dictionary of options to pass to the backend.
        disable (`bool`, defaults to `False`):
            Turn torch.compile() into a no-op for testing
        use_regional_compilation (`bool`, defaults to `None`):
            Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the
            same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`,
            the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the
            model (e.g model.lm_head) is compiled separately. If `None`, it is read from
            `ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION` (falling back to `False`).
    """

    backend: DynamoBackend = field(
        default=None,
        metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
    )
    mode: str = field(
        default=None,
        metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
    )
    fullgraph: bool = field(
        default=None,
        metadata={"help": "Whether to require the whole model to be captured in a single graph (no graph breaks allowed)"},
    )
    dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
    options: Any = field(
        default=None,
        metadata={"help": "A dictionary of options to pass to the backend."},
    )
    disable: bool = field(
        default=False,
        metadata={"help": "Turn torch.compile() into a no-op for testing"},
    )
    use_regional_compilation: bool = field(
        default=None,
        metadata={
            "help": (
                # https://pytorch.org/tutorials/recipes/regional_compilation.html
                "Use it to reduce the cold start compilation time of torch.compile() by targeting repeated "
                "blocks of the same class and compiling them sequentially to hit the compiler's cache. For "
                "example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed "
                "as `model.transformer.h[0]`. The rest of the model (e.g model.lm_head) is compiled separately."
            )
        },
    )

    def __post_init__(self):
        # Any option left as `None` falls back to an `ACCELERATE_DYNAMO_*` environment variable.
        prefix = "ACCELERATE_DYNAMO_"
        if self.backend is None:
            self.backend = os.environ.get(prefix + "BACKEND", "no")
        # Normalize strings (and members) to a `DynamoBackend` member.
        self.backend = DynamoBackend(self.backend.upper())
        if self.mode is None:
            self.mode = os.environ.get(prefix + "MODE", "default")
        if self.fullgraph is None:
            self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
        if self.use_regional_compilation is None:
            self.use_regional_compilation = (
                str_to_bool(os.environ.get(prefix + "USE_REGIONAL_COMPILATION", "False")) == 1
            )
        # `dynamic` deliberately stays `None` unless the env var is explicitly set.
        if self.dynamic is None and os.environ.get(prefix + "USE_DYNAMIC", None) is not None:
            self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1

    def to_dict(self):
        """Return a deep copy of the instance attributes, with `backend` rendered as its lowercase string value."""
        dynamo_config = copy.deepcopy(self.__dict__)
        dynamo_config["backend"] = dynamo_config["backend"].value.lower()
        return dynamo_config

    def to_kwargs(self):
        """Return the parent class's kwargs minus `use_regional_compilation`."""
        kwargs = super().to_kwargs()
        # NOTE(review): `use_regional_compilation` is an Accelerate-side option; presumably it must not be forwarded
        # to `torch.compile` — confirm against the caller.
        kwargs.pop("use_regional_compilation", None)
        return kwargs
@dataclass
class DeepSpeedPlugin:
    """
    This plugin is used to integrate DeepSpeed.

    Args:
        hf_ds_config (`Any`, defaults to `None`):
            Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`.
        gradient_accumulation_steps (`int`, defaults to `None`):
            Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value
            from the `Accelerator` directly.
        gradient_clipping (`float`, defaults to `None`):
            Enable gradient clipping with value.
        zero_stage (`int`, defaults to `None`):
            Possible options are 0, 1, 2, 3. Default will be taken from the `ACCELERATE_DEEPSPEED_ZERO_STAGE`
            environment variable (falling back to 2).
        is_train_batch_min (`bool`, defaults to `True`):
            If both train & eval dataloaders are specified, this will decide the `train_batch_size`.
        offload_optimizer_device (`str`, defaults to `None`):
            Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.
        offload_param_device (`str`, defaults to `None`):
            Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.
        offload_optimizer_nvme_path (`str`, defaults to `None`):
            Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
        offload_param_nvme_path (`str`, defaults to `None`):
            Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
        zero3_init_flag (`bool`, defaults to `None`):
            Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable
            with ZeRO Stage-3; it is forced to `False` (with a warning) otherwise.
        zero3_save_16bit_model (`bool`, defaults to `None`):
            Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.
        transformer_moe_cls_names (`str`, defaults to `None`):
            Comma-separated list of Transformers MoE layer class names (case-sensitive). For example,
            `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc.
        enable_msamp (`bool`, defaults to `None`):
            Flag to indicate whether to enable MS-AMP backend for FP8 training.
        msamp_opt_level (`Optional[Literal["O1", "O2"]]`, defaults to `None`):
            Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one
            of ['O1' or 'O2'].
    """

    hf_ds_config: Any = field(
        default=None,
        metadata={
            "help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`."
        },
    )
    gradient_accumulation_steps: int = field(
        default=None,
        metadata={
            "help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly."
        },
    )
    gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"})
    zero_stage: int = field(
        default=None,
        metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"},
    )
    is_train_batch_min: bool = field(
        default=True,
        metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"},
    )
    offload_optimizer_device: str = field(
        default=None,
        metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."},
    )
    offload_param_device: str = field(
        default=None,
        metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."},
    )
    offload_optimizer_nvme_path: str = field(
        default=None,
        metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."},
    )
    offload_param_nvme_path: str = field(
        default=None,
        metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."},
    )
    zero3_init_flag: bool = field(
        default=None,
        metadata={
            "help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
            " Only applicable with ZeRO Stage-3."
        },
    )
    zero3_save_16bit_model: bool = field(
        default=None,
        metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."},
    )
    transformer_moe_cls_names: str = field(
        default=None,
        metadata={
            "help": "comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
            " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
        },
    )
    enable_msamp: bool = field(
        default=None,
        metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."},
    )
    msamp_opt_level: Optional[Literal["O1", "O2"]] = field(
        default=None,
        metadata={
            "help": "Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one of ['O1' or 'O2']."
        },
    )

    def __post_init__(self):
        from .deepspeed import HfDeepSpeedConfig

        # Options left as `None` fall back to `ACCELERATE_*` environment variables.
        if self.gradient_accumulation_steps is None:
            gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto")
            self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas

        if self.gradient_clipping is None:
            gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "auto")
            self.gradient_clipping = gradient_clipping if gradient_clipping == "auto" else float(gradient_clipping)

        if self.zero_stage is None:
            self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2))

        # The string "none" is used as the "not configured" sentinel for devices, paths and the config file.
        if self.offload_optimizer_device is None:
            self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")

        if self.offload_param_device is None:
            self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none")

        if self.offload_optimizer_nvme_path is None:
            self.offload_optimizer_nvme_path = os.environ.get(
                "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none"
            )

        if self.offload_param_nvme_path is None:
            self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none")

        if self.zero3_save_16bit_model is None:
            self.zero3_save_16bit_model = (
                os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true"
            )
        if self.enable_msamp is None:
            self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP"

        if self.msamp_opt_level is None:
            self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1")

        if self.hf_ds_config is None:
            self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none")
        if (
            isinstance(self.hf_ds_config, dict)
            or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none")
            or isinstance(self.hf_ds_config, HfDeepSpeedConfig)
        ):
            # A user-provided DeepSpeed config (path, dict or object): it takes precedence over plugin fields.
            if not isinstance(self.hf_ds_config, HfDeepSpeedConfig):
                self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config)
            if "gradient_accumulation_steps" not in self.hf_ds_config.config:
                self.hf_ds_config.config["gradient_accumulation_steps"] = 1
            if "zero_optimization" not in self.hf_ds_config.config:
                raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.")

            self._deepspeed_config_checks()
            plugin_to_config_mapping = {
                "gradient_accumulation_steps": "gradient_accumulation_steps",
                "gradient_clipping": "gradient_clipping",
                "zero_stage": "zero_optimization.stage",
                "offload_optimizer_device": "zero_optimization.offload_optimizer.device",
                "offload_param_device": "zero_optimization.offload_param.device",
                "offload_param_nvme_path": "zero_optimization.offload_param.nvme_path",
                "offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path",
                "zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save",
            }
            # Only "auto" entries in the config are filled from the plugin values (must_match=False).
            kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None}
            for key in kwargs.keys():
                self.fill_match(key, **kwargs, must_match=False)
            self.hf_ds_config.set_stage_and_offload()

            # filling the missing values in the class attributes from the DeepSpeed config
            # when using the DeepSpeed config file.
            for key, value in plugin_to_config_mapping.items():
                config_value = self.hf_ds_config.get_value(value)
                if config_value is not None and config_value != "auto":
                    setattr(self, key, config_value)
        else:
            # No config supplied: build one from the plugin fields.
            config = {
                "train_batch_size": "auto",
                "train_micro_batch_size_per_gpu": "auto",
                "gradient_accumulation_steps": self.gradient_accumulation_steps,
                "zero_optimization": {
                    "stage": self.zero_stage,
                    "offload_optimizer": {
                        "device": self.offload_optimizer_device,
                        "nvme_path": (
                            self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
                        ),
                    },
                    "offload_param": {
                        "device": self.offload_param_device,
                        "nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
                    },
                    "stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
                },
            }
            if self.gradient_clipping:
                config["gradient_clipping"] = self.gradient_clipping
            self.hf_ds_config = HfDeepSpeedConfig(config)

        self.deepspeed_config = self.hf_ds_config.config
        self.deepspeed_config["steps_per_print"] = float("inf")  # this will stop deepspeed from logging @ stdout
        if self.zero3_init_flag is None:
            self.zero3_init_flag = (
                str_to_bool(
                    os.environ.get(
                        "ACCELERATE_DEEPSPEED_ZERO3_INIT",
                        str(self.hf_ds_config.is_zero3()),
                    )
                )
                == 1
            )
        if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
            warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
            self.zero3_init_flag = False
        # NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed
        # to the `Accelerator`'s `deepspeed_plugin` param, *or* `AcceleratorState().enable_deepspeed_plugin(plugin_key)` is manually called
        self._set_selected(False)

        # Ignore if it's already set
        if self.enable_msamp and "msamp" not in self.deepspeed_config:
            if self.zero_stage == 3:
                raise NotImplementedError(
                    "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead."
                )
            if self.msamp_opt_level not in ["O1", "O2"]:
                raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or 'O2'].")
            self.deepspeed_config["msamp"] = {
                "enabled": True,
                "opt_level": self.msamp_opt_level,
            }

    def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
        """
        Fill the config entry `ds_key_long` (dotted path) from `kwargs` when it is `"auto"`.

        Raises `ValueError` if the entry is `"auto"` but `ds_key_long` is not in `kwargs`. When `must_match` is True,
        a non-`None` config value that differs from `kwargs[ds_key_long]` is appended to `mismatches`.
        """
        mismatches = [] if mismatches is None else mismatches
        config, ds_key = self.hf_ds_config.find_config_node(ds_key_long)
        if config is None:
            return

        if config.get(ds_key) == "auto":
            if ds_key_long in kwargs:
                config[ds_key] = kwargs[ds_key_long]
                return
            else:
                raise ValueError(
                    f"`{ds_key_long}` not found in kwargs. "
                    f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or "
                    "pass it in kwargs."
                )

        if not must_match:
            return

        ds_val = config.get(ds_key)
        if ds_val is not None and ds_key_long in kwargs:
            if ds_val != kwargs[ds_key_long]:
                mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}")

    def is_auto(self, ds_key_long):
        """Return True if the config value at `ds_key_long` is exactly `"auto"`; False if missing or set."""
        val = self.hf_ds_config.get_value(ds_key_long)
        if val is None:
            return False
        else:
            return val == "auto"

    def get_value(self, ds_key_long, default=None):
        """Delegate to `self.hf_ds_config.get_value`."""
        return self.hf_ds_config.get_value(ds_key_long, default)

    def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs):
        """Process the DeepSpeed config with the values from the kwargs.

        Recursively walks nested dicts (building dotted keys from `prefix`) and calls `fill_match` on every leaf.
        At the top level (`prefix == ""`), any collected mismatches are raised as a single `ValueError`.
        """
        mismatches = [] if mismatches is None else mismatches
        if config is None:
            config = self.deepspeed_config
        for key, value in config.items():
            if isinstance(value, dict):
                self.deepspeed_config_process(
                    prefix=prefix + key + ".",
                    mismatches=mismatches,
                    config=value,
                    must_match=must_match,
                    **kwargs,
                )
            else:
                self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
        if len(mismatches) > 0 and prefix == "":
            mismatches_msg = "\n".join(mismatches)
            raise ValueError(
                "Please correct the following DeepSpeed config values that mismatch kwargs "
                f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
            )

    def set_mixed_precision(self, mixed_precision):
        """Enable the matching fp16/bf16 (and optionally MS-AMP) sections of the DeepSpeed config."""
        ds_config = self.deepspeed_config
        kwargs = {
            "fp16.enabled": mixed_precision == "fp16",
            # When training in fp8, we still rely on bf16 autocast for the core mixed precision
            "bf16.enabled": mixed_precision in ("bf16", "fp8"),
        }
        if mixed_precision == "fp16":
            if "fp16" not in ds_config:
                ds_config["fp16"] = {"enabled": True, "auto_cast": True}
        elif mixed_precision in ("bf16", "fp8"):
            if "bf16" not in ds_config:
                ds_config["bf16"] = {"enabled": True}

        if mixed_precision == "fp8" and self.enable_msamp:
            if "msamp" not in ds_config:
                ds_config["msamp"] = {
                    "enabled": True,
                    "opt_level": self.msamp_opt_level,
                }

        if mixed_precision != "no":
            # Refuse a config that already enables the *other* half-precision dtype.
            diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
            if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true":
                raise ValueError(
                    f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file."
                )
        for dtype in ["fp16", "bf16"]:
            if dtype not in ds_config:
                ds_config[dtype] = {"enabled": False}
        self.fill_match("fp16.enabled", must_match=False, **kwargs)
        self.fill_match("bf16.enabled", must_match=False, **kwargs)

    def set_deepspeed_weakref(self):
        """Register a Transformers `HfDeepSpeedConfig` built from this plugin's config (only when ZeRO-3 init is on)."""
        from .imports import is_transformers_available

        ds_config = copy.deepcopy(self.deepspeed_config)
        if self.zero3_init_flag:
            if not is_transformers_available():
                raise Exception(
                    "When `zero3_init_flag` is set, it requires Transformers to be installed. "
                    "Please run `pip install transformers`."
                )
            # Replace "auto" placeholders that cannot be resolved yet with concrete values.
            if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
                ds_config["gradient_accumulation_steps"] = 1
            if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto":
                ds_config["train_micro_batch_size_per_gpu"] = 1
            if ds_config.get("train_batch_size", None) == "auto":
                del ds_config["train_batch_size"]

            if compare_versions("transformers", "<", "4.46"):
                from transformers.deepspeed import (
                    HfDeepSpeedConfig,
                    unset_hf_deepspeed_config,
                )
            else:
                from transformers.integrations import (
                    HfDeepSpeedConfig,
                    unset_hf_deepspeed_config,
                )

            unset_hf_deepspeed_config()
            self.dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive # noqa

    def is_zero3_init_enabled(self):
        return self.zero3_init_flag

    @contextmanager
    def zero3_init_context_manager(self, enable=False):
        """
        Temporarily set `zero3_init_flag` to `enable` for the duration of the `with` block.

        If the flag already equals `enable` this is a no-op. Otherwise the flag and the registered weakref are
        restored on exit, including when the body raises.
        """
        old = self.zero3_init_flag
        if old == enable:
            yield
            return
        self.zero3_init_flag = enable
        self.dschf = None
        self.set_deepspeed_weakref()
        try:
            yield
        finally:
            self.zero3_init_flag = old
            self.dschf = None
            self.set_deepspeed_weakref()

    def _deepspeed_config_checks(self):
        """Raise if accelerate-config DeepSpeed fields would be silently ignored because a config file is in use."""
        env_variable_names_to_ignore = [
            "ACCELERATE_GRADIENT_ACCUMULATION_STEPS",
            "ACCELERATE_GRADIENT_CLIPPING",
            "ACCELERATE_DEEPSPEED_ZERO_STAGE",
            "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE",
            "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE",
            "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH",
            "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH",
            "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL",
            "ACCELERATE_MIXED_PRECISION",
        ]
        # Convert env var names to the short field names, e.g. "ACCELERATE_DEEPSPEED_ZERO_STAGE" -> "zero_stage".
        env_variable_names_to_ignore = [
            name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore
        ]

        deepspeed_fields_from_accelerate_config = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",")

        if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config):
            raise ValueError(
                f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\n"
                "Please specify them appropriately in the DeepSpeed config file.\n"
                "If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\n"
                "The easiest method is to create a new config following the questionnaire via `accelerate config`.\n"
                "It will only ask for the necessary config variables when using `deepspeed_config_file`."
            )

    def set_moe_leaf_modules(self, model):
        """Mark the MoE layer classes named in `transformer_moe_cls_names` as ZeRO-3 leaf modules on `model`."""
        if self.transformer_moe_cls_names is None:
            self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None)
        if self.transformer_moe_cls_names is not None:
            if compare_versions("deepspeed", "<", "0.14.0"):
                raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.")
            from deepspeed.utils import set_z3_leaf_modules

            class_names = self.transformer_moe_cls_names.split(",")
            transformer_moe_cls = []
            for layer_class in class_names:
                transformer_cls = get_module_class_from_name(model, layer_class)
                if transformer_cls is None:
                    raise Exception(
                        f"Could not find a transformer layer class called '{layer_class}' to wrap in the model."
                    )
                else:
                    transformer_moe_cls.append(transformer_cls)
            set_z3_leaf_modules(model, transformer_moe_cls)  # z3_leaf

    def select(self, _from_accelerator_state: bool = False):
        """
        Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration
        """
        if not _from_accelerator_state:
            raise ValueError(
                "A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`."
            )
        self.set_deepspeed_weakref()
        self._set_selected(True)

    def _unselect(self):
        self._set_selected(False)

    def _set_selected(self, value: bool):
        """
        Private setter for the 'enabled' attribute.
        """
        self._selected = value

    @property
    def selected(self):
        """Whether this plugin is the currently selected one (read-only)."""
        return self._selected

    @selected.setter
    def selected(self, value):
        raise NotImplementedError(
            "'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'."
        )
@dataclass
class FullyShardedDataParallelPlugin:
    """
    This plugin is used to enable fully sharded data parallelism.

    Most fields default to `None`. `__post_init__` then resolves each unset field from the matching `FSDP_*`
    environment variable (for example `FSDP_VERSION`, `FSDP_SHARDING_STRATEGY`, `FSDP_AUTO_WRAP_POLICY`) or from the
    default documented below. It converts strings into the corresponding `torch.distributed.fsdp` objects and raises
    `ValueError` for invalid combinations.

    Args:
        fsdp_version (`int`, defaults to `1`):
            The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to
            FSDP2 format. Any other value raises a `ValueError`.
        sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`):
            Sharding strategy to use. Should be either a `str` or an instance of
            `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of
            `reshard_after_forward`.
        reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`):
            Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of
            `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`.
        backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
            Backward prefetch strategy to use. Should be either a `str` or an instance of
            `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
        mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
            A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
            should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
            `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
            should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
            `reduce_dtype`, and `buffer_dtype`.
        auto_wrap_policy (`Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]]`, defaults to `NO_WRAP`):
            A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
            of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
            `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like.
        cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`):
            Whether to offload parameters to CPU. Should be either a `bool` or an instance of
            `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
            `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
            using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.
        state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
            State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
            `sharded_state_dict`.
        state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]`, defaults to `None`):
            State dict config to use. Is determined based on the `state_dict_type` if not passed in.
        optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]`, defaults to `None`):
            Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in.
        limit_all_gathers (`bool`, defaults to `True`):
            Whether to have FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This
            bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number
            of CUDA malloc retries.
        use_orig_params (`bool`, defaults to `False`):
            Whether to use the original parameters for the optimizer.
        param_init_fn (`Optional[Callable[[torch.nn.Module], None]`, defaults to `None`):
            A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device
            should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. If not
            provided, a `lambda` which calls `to_empty` on the module is used.
        sync_module_states (`bool`, defaults to `False`):
            Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they
            are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading`
            is `True`, then will be forcibly enabled.
        forward_prefetch (`bool`, defaults to `False`):
            Whether to have FSDP explicitly prefetches the next upcoming all-gather while executing in the forward
            pass. only use with Static graphs.
        activation_checkpointing (`bool`, defaults to `False`):
            A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a
            backward pass. Effectively, this trades extra computation time for reduced memory usage.
        cpu_ram_efficient_loading (`bool`, defaults to `None`):
            If True, only the first process loads the pretrained model checkpoint while all other processes have empty
            weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`.
        transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`):
            A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is
            `transformer_based_wrap`.
        min_num_params (`Optional[int]`, defaults to `None`):
            The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
            is `size_based_wrap`.
    """

    fsdp_version: Optional[int] = field(
        default=None,
        metadata={
            "help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format."
        },
    )

    sharding_strategy: Optional[Union[str, "torch.distributed.fsdp.ShardingStrategy"]] = field(
        default=None,
        metadata={
            "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` "
        },
    )

    reshard_after_forward: Optional[Union[str, "torch.distributed.fsdp.ShardingStrategy", bool]] = field(
        default=None,
        metadata={
            "help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'"
        },
    )
    backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field(
        default=None,
        metadata={
            "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2."
        },
    )
    mixed_precision_policy: Optional[
        Union[
            dict,
            str,
            "torch.distributed.fsdp.MixedPrecision",
            "torch.distributed.fsdp.MixedPrecisionPolicy",
        ]
    ] = field(
        default=None,
        metadata={
            "help": "A config to enable mixed precision training with FullyShardedDataParallel. "
            "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`. "
            "Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2."
        },
    )
    auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = (
        field(
            default=None,
            metadata={
                "help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. "
                "Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like"
            },
        )
    )
    cpu_offload: Optional[
        Union[
            bool,
            "torch.distributed.fsdp.CPUOffload",
            "torch.distributed.fsdp.CPUOffloadPolicy",
        ]
    ] = field(
        default=None,
        metadata={
            "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
        },
    )
    ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
        default=None,
        metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
    )

    state_dict_type: Optional[Union[str, "torch.distributed.fsdp.StateDictType"]] = field(
        default=None,
        metadata={
            "help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`"
        },
    )
    state_dict_config: Optional[
        Union[
            "torch.distributed.fsdp.FullStateDictConfig",
            "torch.distributed.fsdp.ShardedStateDictConfig",
        ]
    ] = field(
        default=None,
        metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
    )
    optim_state_dict_config: Optional[
        Union[
            "torch.distributed.fsdp.FullOptimStateDictConfig",
            "torch.distributed.fsdp.ShardedOptimStateDictConfig",
        ]
    ] = field(
        default=None,
        metadata={
            "help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in."
        },
    )
    limit_all_gathers: bool = field(
        default=True,
        metadata={
            "help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent "
            "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. "
            "Enabling this can help lower the number of CUDA malloc retries."
        },
    )
    use_orig_params: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2."
        },
    )
    param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
        default=None,
        metadata={
            "help": "A Callable[torch.nn.Module] -> None that specifies how modules "
            "that are currently on the meta device should be initialized onto an actual device. "
            "Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module."
        },
    )
    sync_module_states: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
            "to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
            "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2."
        },
    )
    forward_prefetch: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to have FSDP explicitly prefetches the next upcoming "
            "all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`"
        },
    )
    activation_checkpointing: Optional[bool] = field(
        default=None,
        metadata={
            "help": "A technique to reduce memory usage by clearing activations of "
            "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time "
            "for reduced memory usage. Defaults to `False`"
        },
    )
    cpu_ram_efficient_loading: Optional[bool] = field(
        default=None,
        metadata={
            "help": "If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
            "Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`."
        },
    )
    transformer_cls_names_to_wrap: Optional[list[str]] = field(
        default=None,
        metadata={
            "help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`."
        },
    )
    min_num_params: Optional[int] = field(
        default=None,
        metadata={
            "help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
        },
    )

    def __post_init__(self):
        """
        Resolve every unset field from the `FSDP_*` environment variables, normalize strings into
        `torch.distributed.fsdp` objects, and validate FSDP1/FSDP2-specific combinations.

        Raises:
            ValueError: for an unsupported `fsdp_version` or an invalid/incompatible option.
            ImportError: when `fsdp_version=2` and the installed PyTorch is older than `FSDP2_PYTORCH_VERSION`.
        """
        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy

        # FSDP2-related notices are collected and logged once at the very end.
        _fsdp2_warnings = set()

        env_prefix = "FSDP_"
        # Strategy: By default we should always assume that values are passed in, else we check the environment variables
        if self.fsdp_version is None:
            self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1"))

        # Every branch below is keyed on version 1 or 2; any other value would silently skip configuration
        # and later fail with an unbound `MixedPrecision` in `set_mixed_precision`.
        if self.fsdp_version not in (1, 2):
            raise ValueError(f"`fsdp_version` must be 1 or 2, got {self.fsdp_version!r}.")

        if self.fsdp_version == 2:
            if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
                raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")

        if self.sharding_strategy is not None:
            # We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard`
            # Therefore we issue a warning only if the user has explicitly set it inside their plugin
            _fsdp2_warnings.add(
                "sharding_strategy is deprecated in favor of reshard_after_forward. "
                "This will be removed in a future version of Accelerate."
            )

        if self.fsdp_version == 1:
            if self.sharding_strategy is None:
                self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD")
            if isinstance(self.sharding_strategy, str):
                # Named strategies are remapped to their 1-based position in FSDP_SHARDING_STRATEGY.
                if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:
                    self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1
                if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():
                    self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))
                else:
                    self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]

        # Read `reshard_after_forward` from the environment only when neither it nor `sharding_strategy` was given.
        # For FSDP1 `sharding_strategy` has already been defaulted above, so in practice this only fires for FSDP2.
        if self.reshard_after_forward is None and self.sharding_strategy is None:
            reshard_after_forward = os.environ.get(
                env_prefix + "RESHARD_AFTER_FORWARD",
                "true" if self.fsdp_version == 2 else "FULL_SHARD",
            )
            if self.fsdp_version == 2:
                self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
            else:
                self.reshard_after_forward = reshard_after_forward
        if isinstance(self.reshard_after_forward, str):
            if self.fsdp_version == 2:
                self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True)
            else:
                # We need to remap based on custom enum values for user readability
                if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY:
                    self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1
                if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit():
                    self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward))
                else:
                    self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]

        if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`"
            )
        if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
            )

        if self.cpu_offload is None:
            self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1

        self.set_cpu_offload()  # abstracted away to hide imports due to version checks
        self.validate_cpu_offload()

        if self.backward_prefetch is None:
            self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
        if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH":
            self.backward_prefetch = None
        if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):
            if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:
                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
                self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
            else:
                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
        if self.fsdp_version == 2 and self.backward_prefetch is not None:
            _fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.")
            self.backward_prefetch = None

        self.set_state_dict_type()

        if self.auto_wrap_policy is None:
            self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP")
        if isinstance(self.auto_wrap_policy, str):
            if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:
                raise ValueError(
                    f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
                )
            from torch.distributed.fsdp.wrap import (
                size_based_auto_wrap_policy,
                transformer_auto_wrap_policy,
            )

            if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
                self.auto_wrap_policy = transformer_auto_wrap_policy
                if self.transformer_cls_names_to_wrap is None:
                    self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None)
                if isinstance(self.transformer_cls_names_to_wrap, str):
                    self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",")
            elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP":
                self.auto_wrap_policy = size_based_auto_wrap_policy
                if self.min_num_params is None:
                    self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0))
                elif not isinstance(self.min_num_params, int):
                    raise ValueError(
                        f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}"
                    )
            elif self.auto_wrap_policy.upper() == "NO_WRAP":
                self.auto_wrap_policy = None

        if self.use_orig_params is None and self.fsdp_version == 1:
            self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
        if self.fsdp_version == 2 and self.use_orig_params is not None:
            _fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
            self.use_orig_params = None

        if self.sync_module_states is None and self.fsdp_version == 1:
            self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1
        if self.fsdp_version == 2 and self.sync_module_states is not None:
            _fsdp2_warnings.add(
                "sync_module_states is obsolete in FSDP2, as it is not needed anymore. "
                "Setting sync_module_states to None."
            )
            self.sync_module_states = None

        if self.forward_prefetch is None and self.fsdp_version == 1:
            self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1
        if self.fsdp_version == 2 and self.forward_prefetch is not None:
            raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`")

        if self.activation_checkpointing is None:
            self.activation_checkpointing = (
                str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
            )

        if self.ignored_modules is None:
            self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)

        if self.cpu_ram_efficient_loading is None:
            self.cpu_ram_efficient_loading = (
                str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
            )
        else:
            # We still need to set it for transformers
            os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading)
        # There's no need to specify sync_module_states in FSDP2
        if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:
            warnings.warn(
                "sync_module_states cannot be False since efficient cpu ram loading enabled. "
                "Setting sync_module_states to True."
            )
            self.sync_module_states = True

        if isinstance(self.mixed_precision_policy, str):
            # override is True since self.mixed_precision_policy is not None
            # has to be overwritten with the correct mixed precision object
            self.set_mixed_precision(self.mixed_precision_policy, override=True)
        elif isinstance(self.mixed_precision_policy, dict):
            self.set_mixed_precision(self.mixed_precision_policy)
        if self.mixed_precision_policy is not None:
            self.validate_mixed_precision_policy()

        if self.sync_module_states:
            device = self._current_accelerator_device()
            # Default initializer for modules still on the meta device when `sync_module_states` is used.
            # A user-supplied `param_init_fn` is respected instead of being silently overwritten.
            if self.param_init_fn is None:
                self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

        if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
            _fsdp2_warnings.add(
                "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0. "
                "Setting ignored_modules to None."
            )
            self.ignored_modules = None

        # Single warning for all deprecation warnings due to FSDP2 conversion
        if _fsdp2_warnings:
            logger.warning(self._format_fsdp2_warnings(_fsdp2_warnings))

    @staticmethod
    def _format_fsdp2_warnings(messages):
        """
        Build the single log message emitted at the end of `__post_init__`.

        The header appears exactly once, followed by one message per line. Messages are sorted so the output does
        not depend on set iteration order.
        """
        return "Multiple deprecation warnings due to FSDP2 conversion:\n" + "\n".join(sorted(messages))

    @staticmethod
    def _current_accelerator_device():
        """
        Return `current_device()` of the first available accelerator backend, checked in the order
        NPU, MLU, MUSA, CUDA, XPU, HPU.

        Raises:
            RuntimeError: if none of these backends is available.
        """
        if is_npu_available():
            return torch.npu.current_device()
        if is_mlu_available():
            return torch.mlu.current_device()
        if is_musa_available():
            return torch.musa.current_device()
        if is_cuda_available():
            return torch.cuda.current_device()
        if is_xpu_available():
            return torch.xpu.current_device()
        if is_hpu_available():
            return torch.hpu.current_device()
        raise RuntimeError(
            "There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'."
        )

    def set_state_dict_type(self, state_dict_type=None):
        """
        Set the state dict config based on the `StateDictType`.

        Args:
            state_dict_type (`Optional[Union[str, StateDictType]]`):
                Overrides `self.state_dict_type` when given. Otherwise the current value is used, falling back to the
                `FSDP_STATE_DICT_TYPE` environment variable and then to `FULL_STATE_DICT` (FSDP1) /
                `SHARDED_STATE_DICT` (FSDP2).

        Missing `state_dict_config` / `optim_state_dict_config` are filled with CPU-offloading defaults for the full
        and sharded types.

        Raises:
            ValueError: if FSDP2 is combined with `LOCAL_STATE_DICT`.
        """
        from torch.distributed.fsdp.fully_sharded_data_parallel import (
            FullOptimStateDictConfig,
            FullStateDictConfig,
            ShardedOptimStateDictConfig,
            ShardedStateDictConfig,
            StateDictType,
        )

        # Override the state_dict_type if provided, typical use case:
        # user trains with sharded, but final save is with full
        if state_dict_type is not None:
            self.state_dict_type = state_dict_type

        if self.state_dict_type is None:
            self.state_dict_type = os.environ.get(
                "FSDP_STATE_DICT_TYPE",
                "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
            )
        if isinstance(self.state_dict_type, str):
            if self.state_dict_type.isdigit():
                self.state_dict_type = StateDictType(int(self.state_dict_type))
            else:
                self.state_dict_type = StateDictType[self.state_dict_type.upper()]

        if self.state_dict_type == StateDictType.FULL_STATE_DICT:
            if self.state_dict_config is None:
                self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            if self.optim_state_dict_config is None:
                self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
            if self.state_dict_config is None:
                self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
            if self.optim_state_dict_config is None:
                self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)

        if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise ValueError(
                "FSDP2 does not support LOCAL_STATE_DICT. "
                "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`."
            )

    def set_auto_wrap_policy(self, model):
        """
        Turn the resolved `auto_wrap_policy` into a callable bound to `model`.

        - Transformer-based: uses `transformer_cls_names_to_wrap`, falling back to `model._no_split_modules`. Each
          name is resolved to a class in `model`.
        - Size-based: binds `min_num_params` when it is positive; otherwise wrapping is disabled (`None`).

        Raises:
            ValueError: if a transformer layer class name cannot be found in `model`.
        """
        from torch.distributed.fsdp.wrap import (
            size_based_auto_wrap_policy,
            transformer_auto_wrap_policy,
        )

        # First base off of `_no_split_modules`
        no_split_modules = getattr(model, "_no_split_modules", None)
        default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else []
        if self.auto_wrap_policy == transformer_auto_wrap_policy:
            if self.transformer_cls_names_to_wrap is None:
                self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap
            transformer_cls_to_wrap = set()
            for layer_class in self.transformer_cls_names_to_wrap:
                transformer_cls = get_module_class_from_name(model, layer_class)
                if transformer_cls is None:
                    raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.")
                transformer_cls_to_wrap.add(transformer_cls)
            # Finally we set the auto_wrap_policy to a callable
            self.auto_wrap_policy = functools.partial(
                self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap
            )

        elif self.auto_wrap_policy == size_based_auto_wrap_policy:
            # If zero, we silently ignore it.
            if self.min_num_params > 0:
                self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params)
            else:
                self.auto_wrap_policy = None

    def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
        """
        Sets the mixed precision policy for FSDP.

        Args:
            mixed_precision (`Union[str, torch.dtype, dict]`):
                One of `"fp8"`, `"fp16"`, `"bf16"`, `"fp32"`, or a matching `torch.dtype`. It is used for both
                `param_dtype` and `reduce_dtype`. A `dict` is only validated when `self.mixed_precision_policy` is
                itself that dict.
            buffer_autocast (`bool`):
                FSDP1 only: keep buffers in `torch.float32` instead of `mixed_precision`.
            override (`bool`):
                Replace an already-set `mixed_precision_policy`.

        Raises:
            ValueError: for an unknown precision string/dtype, or a policy dict with missing keys or unsupported values.
        """
        # Note: "fp8" is mapped to bfloat16 for the FSDP policy itself.
        mixed_precision_mapping = {
            "fp8": torch.bfloat16,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }

        dtype = mixed_precision
        if isinstance(mixed_precision, str):
            dtype = mixed_precision_mapping.get(mixed_precision, None)
            if dtype is None:
                raise ValueError(
                    f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}"
                )
        elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():
            raise ValueError(
                f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}"
            )

        buffer_type = torch.float32 if buffer_autocast else dtype

        if self.fsdp_version == 1:
            from torch.distributed.fsdp import MixedPrecision
        elif self.fsdp_version == 2:
            from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision

        if override or self.mixed_precision_policy is None:
            dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype}
            if self.fsdp_version == 1:
                dtype_args["buffer_dtype"] = buffer_type
            else:
                dtype_args["output_dtype"] = dtype
            # TODO(s1ro1): `cast_forward_inputs` for FSDP2?
            self.mixed_precision_policy = MixedPrecision(**dtype_args)
        elif isinstance(self.mixed_precision_policy, dict):
            # Check for incompatible types
            valid_keys = ["param_dtype", "reduce_dtype"] + (
                ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"]
            )
            missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy]
            invalid_values = [
                k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()
            ]
            if missing_keys or invalid_values:
                raise ValueError(
                    f"Invalid mixed precision policy: {self.mixed_precision_policy}. "
                    f"Must be a `dict` with keys {valid_keys}. "
                    f"Values must be one of {list(mixed_precision_mapping.values())}"
                )
            self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)

    def validate_mixed_precision_policy(self):
        """
        Validates the mixed precision policy, abstracted away to not bring in the imports if not needed.

        Raises:
            ValueError: if `mixed_precision_policy` is not a `MixedPrecision` (FSDP1) / `MixedPrecisionPolicy` (FSDP2).
        """
        if self.fsdp_version == 2:
            from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision
        else:
            from torch.distributed.fsdp import MixedPrecision

        if not isinstance(self.mixed_precision_policy, MixedPrecision):
            required_type = (
                "`torch.distributed.fsdp.MixedPrecisionPolicy`"
                if self.fsdp_version == 2
                else "`torch.distributed.fsdp.MixedPrecision`"
            )
            raise ValueError(f"mixed_precision_policy must be an instance of {required_type}.")

    def set_cpu_offload(self):
        """
        Convert a boolean `cpu_offload` into the version-specific object.

        FSDP2: `CPUOffloadPolicy()` when `True`, `OffloadPolicy()` when `False`.
        FSDP1: `CPUOffload(offload_params=...)`. Non-bool values are left untouched.
        """
        if self.fsdp_version == 2:
            from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
        else:
            from torch.distributed.fsdp import CPUOffload

        if isinstance(self.cpu_offload, bool):
            if self.fsdp_version == 2:
                if not self.cpu_offload:
                    self.cpu_offload = OffloadPolicy()
                else:
                    self.cpu_offload = CPUOffloadPolicy()
            else:
                self.cpu_offload = CPUOffload(offload_params=self.cpu_offload)

    def validate_cpu_offload(self):
        """
        Ensure `cpu_offload` has the type expected by the configured FSDP version.

        Raises:
            ValueError: if it is not an `OffloadPolicy` (FSDP2) or a `CPUOffload` (FSDP1).
        """
        if self.fsdp_version == 2:
            from torch.distributed.fsdp import OffloadPolicy
        else:
            from torch.distributed.fsdp import CPUOffload

        if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy):
            raise ValueError(
                f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}"
            )
        if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload):
            raise ValueError(
                f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}"
            )
@dataclass
class TorchTensorParallelPlugin:
    """
    This plugin is used to enable tensor parallelism using PyTorch >= 2.0.

    Attributes:
        tp_size (`int`, defaults to `1`):
            Tensor parallel size used when preparing the device mesh.
        torch_device_mesh (`Optional[torch.distributed.DeviceMesh]`, defaults to `None`):
            The device mesh to use, if already built.
    """

    tp_size: int = field(
        default=1,
        metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
    )

    # torch_device_mesh is of type "torch.distributed.DeviceMesh"
    torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
@dataclass
class TorchContextParallelConfig:
    """
    This class holds the configuration for context parallelism in PyTorch.

    `cp_comm_strategy` falls back to the `PARALLELISM_CONFIG_CP_COMM_STRATEGY` environment variable and then to
    `"allgather"`.

    Raises:
        ValueError: on PyTorch older than `BETA_CP_AVAILABLE_PYTORCH_VERSION`, or when `cp_comm_strategy` is not
            `"allgather"` / `"alltoall"`.
    """

    cp_comm_strategy: Optional[str] = field(
        default=None,
        metadata={
            "help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'."
        },
    )

    def __post_init__(self):
        # Version gate first: nothing else matters if the PyTorch build cannot run context parallelism.
        if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION):
            raise ValueError(
                f"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. "
                "Please upgrade your PyTorch version."
            )
        if self.cp_comm_strategy is None:
            self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather")
        if self.cp_comm_strategy not in ["allgather", "alltoall"]:
            raise ValueError(
                f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'."
            )
@dataclass
class DeepSpeedSequenceParallelConfig:
    """
    Configuration for DeepSpeed sequence parallelism.

    Every field left as `None` is resolved in `__post_init__` from a `PARALLELISM_CONFIG_SP_*` environment variable:

    - `sp_seq_length_is_variable` <- `PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE` (default "true").
    - `sp_seq_length` <- `PARALLELISM_CONFIG_SP_SEQ_LENGTH`. This is only consulted, and then required, when the
      sequence length is not variable. The literal string "None" means unset.
    - `sp_attn_implementation` <- `PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION` (default unset).
    """

    sp_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset"
        },
    )
    sp_seq_length_is_variable: Optional[bool] = field(
        default=None,
        metadata={
            "help": "If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`."
        },
    )
    sp_attn_implementation: Optional[str] = field(
        default=None,
        metadata={
            "help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`."
        },
    )

    def __post_init__(self):
        """
        Resolve unset fields from the environment and validate them.

        Raises:
            ValueError: if a fixed sequence length is requested but no `sp_seq_length` is available, or if
                `sp_attn_implementation` is not one of the supported values.
        """
        # sp_seq_length_is_variable and sp_seq_length are interconnected
        if self.sp_seq_length_is_variable is None:
            self.sp_seq_length_is_variable = (
                os.environ.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE", "true").lower() == "true"
            )
        if not self.sp_seq_length_is_variable and self.sp_seq_length is None:
            if "PARALLELISM_CONFIG_SP_SEQ_LENGTH" not in os.environ:
                raise ValueError(
                    "when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH"
                )
            self.sp_seq_length = os.environ["PARALLELISM_CONFIG_SP_SEQ_LENGTH"]
        # Normalise string values (e.g. from the environment) to int, treating "None" as unset. Only convert a
        # value that is present: an unconditional int() crashed the default, variable-length configuration.
        if self.sp_seq_length is not None:
            self.sp_seq_length = None if self.sp_seq_length == "None" else int(self.sp_seq_length)

        if self.sp_attn_implementation is None:
            self.sp_attn_implementation = os.environ.get("PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION", None)
        if self.sp_attn_implementation is not None and self.sp_attn_implementation not in [
            "flash_attention_2",
            "flash_attention_3",
            "sdpa",
        ]:
            raise ValueError(
                f"Invalid sp_attn_implementation: {self.sp_attn_implementation}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'."
            )
@dataclass
class TorchTensorParallelConfig:
    """
    Use this object in your [`Accelerator`] to customize your torch tensor parallelism.

    Attributes:
        enable_async_tp (`bool`, defaults to `False`):
            Accepted for forward compatibility only; if set, a warning is emitted and the option is ignored.
    """

    enable_async_tp: bool = False

    def __post_init__(self):
        """
        Validate the installed PyTorch and transformers versions.

        Raises:
            ValueError: if PyTorch is older than `BETA_TP_AVAILABLE_PYTORCH_VERSION` or transformers is older than
                `BETA_TP_AVAILABLE_TRANSFORMERS_VERSION`.
        """
        if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
            raise ValueError(
                f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
                "Please upgrade your PyTorch version."
            )
        if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
            raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
        if self.enable_async_tp:
            warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
@dataclass
class MegatronLMPlugin:
    """
    Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective
    activation recomputation and optimized fused kernels.

    Fields left as `None` are filled in by `__post_init__` from `MEGATRON_LM_*` environment variables (falling back
    to a built-in default). All resolved values are collected into `megatron_lm_default_args`. Megatron-Core options
    not listed below (e.g. `use_custom_fsdp`, `recompute_granularity`, `context_parallel_size`) are described by the
    `help` metadata on their fields.

    Args:
        tp_degree (`int`, defaults to `None`):
            Tensor parallelism degree.
        pp_degree (`int`, defaults to `None`):
            Pipeline parallelism degree.
        num_micro_batches (`int`, defaults to `None`):
            Number of micro-batches.
        gradient_clipping (`float`, defaults to `None`):
            Gradient clipping value based on global L2 Norm (0 to disable).
        sequence_parallelism (`bool`, defaults to `None`):
            Enable sequence parallelism.
        recompute_activations (`bool`, defaults to `None`):
            Enable selective activation recomputation.
        use_distributed_optimizer (`bool`, defaults to `None`):
            Enable distributed optimizer.
        pipeline_model_parallel_split_rank (`int`, defaults to `None`):
            Rank where encoder and decoder should be split.
        num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`):
            Number of layers per virtual pipeline stage.
        is_train_batch_min (`bool`, defaults to `True`):
            If both train & eval dataloaders are specified, this will decide the `micro_batch_size`.
        train_iters (`int`, defaults to `None`):
            Total number of iterations to train over all training runs. Note that either train-iters or
            train-samples should be provided when using `MegatronLMDummyScheduler`.
        train_samples (`int`, defaults to `None`):
            Total number of samples to train over all training runs. Note that either train-iters or train-samples
            should be provided when using `MegatronLMDummyScheduler`.
        weight_decay_incr_style (`str`, defaults to `'constant'`):
            Weight decay increment function. choices=["constant", "linear", "cosine"].
        start_weight_decay (`float`, defaults to `None`):
            Initial weight decay coefficient for L2 regularization.
        end_weight_decay (`float`, defaults to `None`):
            End of run weight decay coefficient for L2 regularization.
        lr_decay_style (`str`, defaults to `'linear'`):
            Learning rate decay function. choices=['constant', 'linear', 'cosine'].
        lr_decay_iters (`int`, defaults to `None`):
            Number of iterations for learning rate decay. If None defaults to `train_iters`.
        lr_decay_samples (`int`, defaults to `None`):
            Number of samples for learning rate decay. If None defaults to `train_samples`.
        lr_warmup_iters (`int`, defaults to `None`):
            Number of iterations to linearly warmup learning rate over.
        lr_warmup_samples (`int`, defaults to `None`):
            Number of samples to linearly warmup learning rate over.
        lr_warmup_fraction (`float`, defaults to `None`):
            Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
        min_lr (`float`, defaults to `0`):
            Minimum value for learning rate. The scheduler clip values below this threshold.
        consumed_samples (`List`, defaults to `None`):
            Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
        no_wd_decay_cond (`Optional`, defaults to `None`):
            Condition to disable weight decay.
        scale_lr_cond (`Optional`, defaults to `None`):
            Condition to scale learning rate.
        lr_mult (`float`, defaults to `1.0`):
            Learning rate multiplier.
        megatron_dataset_flag (`bool`, defaults to `False`):
            Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.
        seq_length (`int`, defaults to `None`):
            Maximum sequence length to process.
        encoder_seq_length (`int`, defaults to `None`):
            Maximum sequence length to process for the encoder.
        decoder_seq_length (`int`, defaults to `None`):
            Maximum sequence length to process for the decoder.
        tensorboard_dir (`str`, defaults to `None`):
            Path to save tensorboard logs.
        set_all_logging_options (`bool`, defaults to `False`):
            Whether to set all logging options.
        eval_iters (`int`, defaults to `100`):
            Number of iterations to run for evaluation validation/test for.
        eval_interval (`int`, defaults to `1000`):
            Interval between running evaluation on validation set.
        return_logits (`bool`, defaults to `False`):
            Whether to return logits from the model.
        custom_train_step_class (`Optional`, defaults to `None`):
            Custom train step class.
        custom_train_step_kwargs (`Optional`, defaults to `None`):
            Custom train step kwargs.
        custom_model_provider_function (`Optional`, defaults to `None`):
            Custom model provider function.
        custom_prepare_model_function (`Optional`, defaults to `None`):
            Custom prepare model function.
        custom_megatron_datasets_provider_function (`Optional`, defaults to `None`):
            Custom megatron train_valid_test datasets provider function.
        custom_get_batch_function (`Optional`, defaults to `None`):
            Custom get batch function.
        custom_loss_function (`Optional`, defaults to `None`):
            Custom loss function.
        other_megatron_args (`Optional`, defaults to `None`):
            Other Megatron-LM arguments. Please refer Megatron-LM.
    """

    tp_degree: Optional[int] = field(default=None, metadata={"help": "tensor parallelism degree."})
    pp_degree: Optional[int] = field(default=None, metadata={"help": "pipeline parallelism degree."})
    use_custom_fsdp: Optional[bool] = field(default=None, metadata={"help": "use custom fsdp."})
    overlap_cpu_optimizer_d2h_h2d: Optional[bool] = field(
        default=None, metadata={"help": "overlap CPU optimizer step, gradients D2H and updated parameters H2D."}
    )
    no_load_optim: Optional[bool] = field(default=None, metadata={"help": "do not load optimizer."})
    eod_mask_loss: Optional[bool] = field(default=None, metadata={"help": "use eod mask loss."})
    no_save_optim: Optional[bool] = field(default=None, metadata={"help": "do not save optimizer."})
    optimizer_cpu_offload: Optional[bool] = field(default=None, metadata={"help": "use CPU offload for optimizer."})
    use_precision_aware_optimizer: Optional[bool] = field(
        default=None, metadata={"help": "use precision aware optimizer."}
    )
    decoder_last_pipeline_num_layers: Optional[int] = field(
        default=None,
        metadata={
            "help": "decoder last pipeline number of layers, default None is even split of transformer layers across all pipeline stages."
        },
    )
    recompute_granularity: Optional[str] = field(
        default=None, metadata={"help": "recompute granularity (full, selective)."}
    )
    recompute_method: Optional[str] = field(default=None, metadata={"help": "recompute method (uniform, block)."})
    recompute_num_layers: Optional[int] = field(default=None, metadata={"help": "number of layers to recompute."})
    attention_backend: Optional[bool] = field(default=None, metadata={"help": "enable attention backend."})
    expert_model_parallel_size: Optional[int] = field(default=None, metadata={"help": "expert model parallel size."})
    context_parallel_size: Optional[int] = field(default=None, metadata={"help": "context parallel size."})
    attention_dropout: Optional[float] = field(default=None, metadata={"help": "attention dropout rate."})
    hidden_dropout: Optional[float] = field(default=None, metadata={"help": "hidden dropout rate."})
    attention_softmax_in_fp32: Optional[bool] = field(
        default=None, metadata={"help": "use fp32 for attention softmax."}
    )
    expert_tensor_parallel_size: Optional[int] = field(
        default=None, metadata={"help": "expert tensor parallel size."}
    )
    calculate_per_token_loss: Optional[bool] = field(default=None, metadata={"help": "calculate per token loss."})
    use_rotary_position_embeddings: Optional[bool] = field(
        default=None, metadata={"help": "use rotary position embeddings."}
    )
    num_micro_batches: Optional[int] = field(default=None, metadata={"help": "number of micro-batches."})
    gradient_clipping: Optional[float] = field(
        default=None,
        metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
    )
    sequence_parallelism: Optional[bool] = field(
        default=None,
        metadata={"help": "enable sequence parallelism"},
    )
    recompute_activations: Optional[bool] = field(
        default=None,
        metadata={"help": "enable selective activation recomputation"},
    )
    use_distributed_optimizer: Optional[bool] = field(
        default=None,
        metadata={"help": "enable distributed optimizer"},
    )
    pipeline_model_parallel_split_rank: Optional[int] = field(
        default=None,
        metadata={"help": "Rank where encoder and decoder should be split."},
    )
    num_layers_per_virtual_pipeline_stage: Optional[int] = field(
        default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
    )
    is_train_batch_min: bool = field(
        default=True,
        metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"},
    )
    train_iters: Optional[int] = field(
        default=None,
        metadata={
            "help": "Total number of iterations to train over all training runs. "
            "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`"
        },
    )
    train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "Total number of samples to train over all training runs. "
            "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`"
        },
    )
    weight_decay_incr_style: str = field(
        default="constant",
        metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '},
    )
    start_weight_decay: Optional[float] = field(
        default=None,
        metadata={"help": "Initial weight decay coefficient for L2 regularization."},
    )
    end_weight_decay: Optional[float] = field(
        default=None,
        metadata={"help": "End of run weight decay coefficient for L2 regularization."},
    )
    lr_decay_style: str = field(
        default="linear",
        metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."},
    )
    lr_decay_iters: Optional[int] = field(
        default=None,
        metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."},
    )
    lr_decay_samples: Optional[int] = field(
        default=None,
        metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."},
    )
    lr_warmup_iters: Optional[int] = field(
        default=None,
        metadata={"help": "number of iterations to linearly warmup learning rate over."},
    )
    lr_warmup_samples: Optional[int] = field(
        default=None,
        metadata={"help": "number of samples to linearly warmup learning rate over."},
    )
    lr_warmup_fraction: Optional[float] = field(
        default=None,
        metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."},
    )
    min_lr: float = field(
        default=0,
        metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
    )
    consumed_samples: Optional[list[int]] = field(
        default=None,
        metadata={
            "help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call."
        },
    )
    no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."})
    scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."})
    lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."})
    megatron_dataset_flag: bool = field(
        default=False,
        metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."},
    )
    seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum sequence length to process."},
    )
    encoder_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum sequence length to process for the encoder."},
    )
    decoder_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Maximum sequence length to process for the decoder."},
    )
    tensorboard_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save tensorboard logs."},
    )
    set_all_logging_options: bool = field(
        default=False,
        metadata={"help": "Whether to set all logging options."},
    )
    eval_iters: int = field(
        default=100,
        metadata={"help": "Number of iterations to run for evaluation validation/test for."},
    )
    eval_interval: int = field(
        default=1000,
        metadata={"help": "Interval between running evaluation on validation set."},
    )
    return_logits: bool = field(
        default=False,
        metadata={"help": "Whether to return logits from the model."},
    )

    # custom train step args
    custom_train_step_class: Optional[Any] = field(
        default=None,
        metadata={"help": "Custom train step class."},
    )
    custom_train_step_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={"help": "Custom train step kwargs."},
    )

    # custom model args
    custom_model_provider_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom model provider function."},
    )
    custom_prepare_model_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom prepare model function."},
    )
    custom_megatron_datasets_provider_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom megatron train_valid_test datasets provider function."},
    )
    custom_get_batch_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom get batch function."},
    )
    custom_loss_function: Optional[Callable] = field(
        default=None,
        metadata={"help": "Custom loss function."},
    )

    # remaining args such as enabling Alibi/ROPE positional embeddings,
    # wandb logging, Multi-Query Attention, etc.
    other_megatron_args: Optional[dict[str, Any]] = field(
        default=None,
        metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"},
    )

    def __post_init__(self):
        """
        Fill unset fields from `MEGATRON_LM_*` environment variables and build `megatron_lm_default_args`.

        Also sets `DDP_impl` ("local" when pipeline parallelism or the distributed optimizer is enabled, otherwise
        "torch") and merges `other_megatron_args` last, so those entries override the computed defaults.
        """
        prefix = "MEGATRON_LM_"

        def env(name, default=None):
            # Raw (string) value of MEGATRON_LM_<name>, or `default` when the variable is unset.
            return os.environ.get(prefix + name, default)

        def env_flag(name, default="False"):
            # Boolean flag parsed with `str_to_bool`, compared to 1 as elsewhere in this module.
            return str_to_bool(env(name, default)) == 1

        if self.tp_degree is None:
            self.tp_degree = int(env("TP_DEGREE", 1))
        if self.pp_degree is None:
            self.pp_degree = int(env("PP_DEGREE", 1))
        if self.use_custom_fsdp is None:
            self.use_custom_fsdp = env_flag("USE_CUSTOM_FSDP")
        if self.no_load_optim is None:
            self.no_load_optim = env_flag("NO_LOAD_OPTIM")
        if self.eod_mask_loss is None:
            self.eod_mask_loss = env_flag("EOD_MASK_LOSS")
        if self.no_save_optim is None:
            self.no_save_optim = env_flag("NO_SAVE_OPTIM")
        if self.optimizer_cpu_offload is None:
            self.optimizer_cpu_offload = env_flag("OPTIMIZER_CPU_OFFLOAD")
        if self.overlap_cpu_optimizer_d2h_h2d is None:
            self.overlap_cpu_optimizer_d2h_h2d = env_flag("OVERLAP_CPU_OPTIMIZER_D2H_H2D")
        if self.use_precision_aware_optimizer is None:
            self.use_precision_aware_optimizer = env_flag("USE_PRECISION_AWARE_OPTIMIZER")
        if self.decoder_last_pipeline_num_layers is None:
            # Stays None (even split across pipeline stages) unless the variable is set.
            raw_last_layers = env("DECODER_LAST_PIPELINE_NUM_LAYERS")
            self.decoder_last_pipeline_num_layers = int(raw_last_layers) if raw_last_layers is not None else None
        if self.num_micro_batches is None:
            self.num_micro_batches = int(env("NUM_MICRO_BATCHES", 1))
        if self.gradient_clipping is None:
            self.gradient_clipping = float(env("GRADIENT_CLIPPING", 1.0))
        if self.recompute_activations is None:
            self.recompute_activations = env_flag("RECOMPUTE_ACTIVATIONS")
        if self.use_distributed_optimizer is None:
            self.use_distributed_optimizer = env_flag("USE_DISTRIBUTED_OPTIMIZER")
        if self.sequence_parallelism is None:
            self.sequence_parallelism = env_flag("SEQUENCE_PARALLELISM")
        if self.recompute_granularity is None:
            self.recompute_granularity = env("RECOMPUTE_GRANULARITY", "full")
        if self.recompute_method is None:
            self.recompute_method = env("RECOMPUTE_METHOD", "uniform")
        if self.recompute_num_layers is None:
            self.recompute_num_layers = int(env("RECOMPUTE_NUM_LAYERS", 1))
        if self.attention_backend is None:
            self.attention_backend = env_flag("ATTENTION_BACKEND", "True")
        if self.expert_model_parallel_size is None:
            self.expert_model_parallel_size = int(env("EXPERT_MODEL_PARALLEL_SIZE", 1))
        if self.context_parallel_size is None:
            # Default to 1 (context parallelism off), like every other parallel-size default here; a default of
            # 2 silently enabled context parallelism for every run.
            self.context_parallel_size = int(env("CONTEXT_PARALLEL_SIZE", 1))
        if self.attention_dropout is None:
            self.attention_dropout = float(env("ATTENTION_DROPOUT", "0.0"))
        if self.hidden_dropout is None:
            self.hidden_dropout = float(env("HIDDEN_DROPOUT", "0.0"))
        if self.attention_softmax_in_fp32 is None:
            self.attention_softmax_in_fp32 = env_flag("ATTENTION_SOFTMAX_IN_FP32", "True")
        if self.expert_tensor_parallel_size is None:
            self.expert_tensor_parallel_size = int(env("EXPERT_TENSOR_PARALLEL_SIZE", 1))
        if self.calculate_per_token_loss is None:
            self.calculate_per_token_loss = env_flag("CALCULATE_PER_TOKEN_LOSS", "True")
        if self.use_rotary_position_embeddings is None:
            self.use_rotary_position_embeddings = env_flag("USE_ROTARY_POSITION_EMBEDDINGS", "True")

        self.DDP_impl = "local" if self.pp_degree > 1 or self.use_distributed_optimizer else "torch"

        if self.consumed_samples is not None and len(self.consumed_samples) in (1, 2):
            # Pad with zeros to three entries. Build a new list instead of extending in place so the
            # caller's list object is not modified.
            self.consumed_samples = [*self.consumed_samples, *([0] * (3 - len(self.consumed_samples)))]

        self.megatron_lm_default_args = {
            "tensor_model_parallel_size": self.tp_degree,
            "pipeline_model_parallel_size": self.pp_degree,
            "pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank,
            "num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage,
            "DDP_impl": self.DDP_impl,
            "use_distributed_optimizer": self.use_distributed_optimizer,
            "sequence_parallel": self.sequence_parallelism,
            "clip_grad": self.gradient_clipping,
            "num_micro_batches": self.num_micro_batches,
            "consumed_samples": self.consumed_samples,
            "no_wd_decay_cond": self.no_wd_decay_cond,
            "scale_lr_cond": self.scale_lr_cond,
            "lr_mult": self.lr_mult,
            "megatron_dataset_flag": self.megatron_dataset_flag,
            "eval_iters": self.eval_iters,
            "eval_interval": self.eval_interval,
            "use_custom_fsdp": self.use_custom_fsdp,
            "no_load_optim": self.no_load_optim,
            "eod_mask_loss": self.eod_mask_loss,
            "no_save_optim": self.no_save_optim,
            "optimizer_cpu_offload": self.optimizer_cpu_offload,
            "overlap_cpu_optimizer_d2h_h2d": self.overlap_cpu_optimizer_d2h_h2d,
            "use_precision_aware_optimizer": self.use_precision_aware_optimizer,
            "decoder_last_pipeline_num_layers": self.decoder_last_pipeline_num_layers,
            "recompute_granularity": self.recompute_granularity,
            "recompute_method": self.recompute_method,
            "recompute_num_layers": self.recompute_num_layers,
            "attention_backend": self.attention_backend,
            "expert_model_parallel_size": self.expert_model_parallel_size,
            "context_parallel_size": self.context_parallel_size,
            "attention_dropout": self.attention_dropout,
            "hidden_dropout": self.hidden_dropout,
            "attention_softmax_in_fp32": self.attention_softmax_in_fp32,
            "expert_tensor_parallel_size": self.expert_tensor_parallel_size,
            "calculate_per_token_loss": self.calculate_per_token_loss,
            "use_rotary_position_embeddings": self.use_rotary_position_embeddings,
        }
        if self.tensorboard_dir is not None:
            self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir
            if self.set_all_logging_options:
                self.set_tensorboard_logging_options()
        if self.other_megatron_args is not None:
            self.megatron_lm_default_args.update(self.other_megatron_args)

    def set_network_size_args(self, model, batch_data=None):
        """
        Run the registered config parser matching `model`.

        The first key of `MODEL_CONFIGS_TO_MEGATRON_PARSERS` that is a substring of the lower-cased
        `model.config.model_type` selects the parser, which is called as `parser(self, model, batch_data)`.

        Raises:
            ValueError: if no registered key matches.
        """
        model_config_type = model.config.model_type.lower()
        for model_type, parser in MODEL_CONFIGS_TO_MEGATRON_PARSERS.items():
            if model_type in model_config_type:
                parser(self, model, batch_data)
                return
        raise ValueError(
            f"Accelerate Megatron-LM integration not supports {model_config_type} model. "
            "You can add your own model config parser."
        )

    def set_mixed_precision(self, mixed_precision):
        """
        Enable "fp16" or "bf16" in `megatron_lm_default_args`; any other value is ignored.

        For "bf16" the DDP implementation is also switched to "local".
        """
        if mixed_precision == "fp16":
            self.megatron_lm_default_args["fp16"] = True
        elif mixed_precision == "bf16":
            self.megatron_lm_default_args["bf16"] = True
            self.DDP_impl = "local"
            self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl

    def set_training_args(self, micro_batch_size, dp_degree):
        """
        Record batch sizing.

        The global batch size is `dp_degree * micro_batch_size * num_micro_batches`.
        """
        self.data_parallel_size = dp_degree
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches
        self.megatron_lm_default_args["data_parallel_size"] = self.data_parallel_size
        self.megatron_lm_default_args["micro_batch_size"] = self.micro_batch_size
        self.megatron_lm_default_args["global_batch_size"] = self.global_batch_size

    def set_optimizer_type(self, optimizer):
        """
        Translate a PyTorch optimizer's `defaults` into Megatron optimizer arguments.

        Any class whose lower-cased name contains "adam" is mapped to Megatron's "adam", and any containing "sgd"
        to "sgd". `lr` and `weight_decay` are copied in both cases.

        Raises:
            ValueError: for any other optimizer class.
        """
        optimizer_name = optimizer.__class__.__name__.lower()
        defaults = optimizer.defaults
        if "adam" in optimizer_name:
            self.megatron_lm_default_args["optimizer"] = "adam"
            self.megatron_lm_default_args["adam_beta1"] = defaults["betas"][0]
            self.megatron_lm_default_args["adam_beta2"] = defaults["betas"][1]
            self.megatron_lm_default_args["adam_eps"] = defaults["eps"]
        elif "sgd" in optimizer_name:
            self.megatron_lm_default_args["optimizer"] = "sgd"
            self.megatron_lm_default_args["sgd_momentum"] = defaults["momentum"]
        else:
            raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM")

        self.megatron_lm_default_args["lr"] = defaults["lr"]
        self.megatron_lm_default_args["weight_decay"] = defaults["weight_decay"]

    def set_scheduler_args(self, scheduler):
        """
        Copy learning-rate schedule settings into `megatron_lm_default_args`.

        When `train_iters` / `lr_warmup_iters` are unset, they are derived from the scheduler's
        `total_num_steps` / `warmup_num_steps` divided by the data-parallel size. In that case any user-given
        sample-based counterpart is dropped with a warning. Requires `set_training_args` to have stored
        "data_parallel_size" whenever a derivation is needed.
        """
        if self.train_iters is None:
            self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"]
            if self.train_samples is not None:
                self.train_samples = None
                warnings.warn(
                    "Ignoring `train_samples` as `train_iters` based on scheduler is being used for training."
                )
        if self.lr_warmup_iters is None:
            self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"]
            if self.lr_warmup_samples is not None:
                warnings.warn(
                    "Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training."
                )
            self.lr_warmup_samples = 0

        self.megatron_lm_default_args.update(
            {
                "train_iters": self.train_iters,
                "lr_warmup_iters": self.lr_warmup_iters,
                "train_samples": self.train_samples,
                "lr_warmup_samples": self.lr_warmup_samples,
                "lr_decay_iters": self.lr_decay_iters,
                "lr_decay_samples": self.lr_decay_samples,
                "lr_warmup_fraction": self.lr_warmup_fraction,
                "lr_decay_style": self.lr_decay_style,
                "weight_decay_incr_style": self.weight_decay_incr_style,
                "start_weight_decay": self.start_weight_decay,
                "end_weight_decay": self.end_weight_decay,
                "min_lr": self.min_lr,
            }
        )

    def set_tensorboard_logging_options(self):
        """
        Turn on every Megatron-LM logging option.

        Megatron's logging arguments are parsed from the process command line (`parse_known_args()` with no
        arguments reads `sys.argv`). Every `log_*` key, and the `log_*` counterpart of every `no_log_*` key,
        is then set to True.
        """
        from megatron.training.arguments import _add_logging_args

        parser = argparse.ArgumentParser()
        parser = _add_logging_args(parser)
        logging_args = parser.parse_known_args()
        self.dataset_args = vars(logging_args[0])
        for key in self.dataset_args:
            if key.startswith("log_"):
                self.megatron_lm_default_args[key] = True
            elif key.startswith("no_log_"):
                # Strip only the leading "no_"; str.replace would also remove later occurrences in the key.
                self.megatron_lm_default_args[key.removeprefix("no_")] = True
# Registry mapping a model-type substring to the function that copies that model's network size into a
# `MegatronLMPlugin`. Filled by the `add_model_config_to_megatron_parser` decorator.
MODEL_CONFIGS_TO_MEGATRON_PARSERS = {}


def add_model_config_to_megatron_parser(model_type: str):
    """
    Return a decorator that registers a model-config parser under `model_type`.

    The undecorated function is stored in `MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type]`, replacing any parser
    previously registered for the same key. The decorated name is bound to a pass-through wrapper that keeps the
    parser's `__name__` and `__doc__`.

    Args:
        model_type (`str`): key to register under; matched as a substring of the lower-cased
            `model.config.model_type`.
    """

    def add_model_config_parser_helper(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func
        return wrapper

    return add_model_config_parser_helper
@add_model_config_to_megatron_parser("megatron-bert")
def parse_bert_config(megatron_lm_plugin, model, batch_data):
    """
    Copy the network size of a BERT-style model into `megatron_lm_plugin.megatron_lm_default_args`.

    Registered under the "megatron-bert" key of `MODEL_CONFIGS_TO_MEGATRON_PARSERS`.

    The sequence length is chosen in this order: `encoder_seq_length` (with a warning if `seq_length` is also
    set), an explicitly set `seq_length`, the width of `batch_data["input_ids"]`, and finally
    `max_position_embeddings`.

    Args:
        megatron_lm_plugin: plugin whose `seq_length` is resolved and whose `megatron_lm_default_args` are
            updated in place.
        model: model exposing `config.num_hidden_layers`, `hidden_size`, `num_attention_heads`,
            `max_position_embeddings`, `num_labels`, `vocab_size` and `return_dict`.
        batch_data: optional sample batch containing an "input_ids" entry with a `.shape`.
    """
    model_type_name = "bert"
    num_layers = model.config.num_hidden_layers
    hidden_size = model.config.hidden_size
    num_attention_heads = model.config.num_attention_heads
    max_position_embeddings = model.config.max_position_embeddings
    num_labels = model.config.num_labels
    orig_vocab_size = model.config.vocab_size
    # Masked-LM heads are treated as pretraining.
    pretraining_flag = "maskedlm" in model.__class__.__name__.lower()

    if megatron_lm_plugin.seq_length is not None:
        # Only override an explicit seq_length when encoder_seq_length is actually set; otherwise keep the user's
        # value rather than replacing it with None.
        if megatron_lm_plugin.encoder_seq_length is not None:
            warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.")
            megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
    elif megatron_lm_plugin.encoder_seq_length is not None:
        megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
    elif batch_data is not None:
        megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
    else:
        megatron_lm_plugin.seq_length = max_position_embeddings

    args = megatron_lm_plugin.megatron_lm_default_args
    args["seq_length"] = megatron_lm_plugin.seq_length
    args["model_type_name"] = model_type_name
    args["num_layers"] = num_layers
    args["hidden_size"] = hidden_size
    args["num_attention_heads"] = num_attention_heads
    args["max_position_embeddings"] = max_position_embeddings
    args["pretraining_flag"] = pretraining_flag
    args["orig_vocab_size"] = orig_vocab_size
    args["model_return_dict"] = model.config.return_dict
    args["num_labels"] = num_labels
@add_model_config_to_megatron_parser("gpt2")
def parse_gpt2_config(megatron_lm_plugin, model, batch_data):
    """
    Copy the network size of a GPT-2-style model into `megatron_lm_plugin.megatron_lm_default_args`.

    Registered under the "gpt2" key of `MODEL_CONFIGS_TO_MEGATRON_PARSERS`.

    The sequence length is chosen in this order: `decoder_seq_length` (with a warning if `seq_length` is also
    set), an explicitly set `seq_length`, the width of `batch_data["input_ids"]`, and finally `n_positions`.
    Also records `return_logits` and selects the "GPT2BPETokenizer" tokenizer.

    Args:
        megatron_lm_plugin: plugin whose `seq_length` is resolved and whose `megatron_lm_default_args` are
            updated in place.
        model: model exposing `config.n_layer`, `n_embd`, `n_head`, `n_positions`, `vocab_size` and
            `return_dict`.
        batch_data: optional sample batch containing an "input_ids" entry with a `.shape`.
    """
    model_type_name = "gpt"
    num_layers = model.config.n_layer
    hidden_size = model.config.n_embd
    num_attention_heads = model.config.n_head
    max_position_embeddings = model.config.n_positions
    orig_vocab_size = model.config.vocab_size
    pretraining_flag = True

    if megatron_lm_plugin.seq_length is not None:
        # Only override an explicit seq_length when decoder_seq_length is actually set; otherwise keep the user's
        # value rather than replacing it with None.
        if megatron_lm_plugin.decoder_seq_length is not None:
            warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.")
            megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
    elif megatron_lm_plugin.decoder_seq_length is not None:
        megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
    elif batch_data is not None:
        megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
    else:
        megatron_lm_plugin.seq_length = max_position_embeddings

    args = megatron_lm_plugin.megatron_lm_default_args
    args["seq_length"] = megatron_lm_plugin.seq_length
    args["return_logits"] = megatron_lm_plugin.return_logits
    args["tokenizer_type"] = "GPT2BPETokenizer"
    args["model_type_name"] = model_type_name
    args["num_layers"] = num_layers
    args["hidden_size"] = hidden_size
    args["num_attention_heads"] = num_attention_heads
    args["max_position_embeddings"] = max_position_embeddings
    args["pretraining_flag"] = pretraining_flag
    args["orig_vocab_size"] = orig_vocab_size
    args["model_return_dict"] = model.config.return_dict
def parse_t5_config(megatron_lm_plugin, model, batch_data):
    """
    Populate `megatron_lm_plugin.megatron_lm_default_args` from a T5 Hugging Face model config.

    Unset `encoder_seq_length` / `decoder_seq_length` on the plugin are filled from the second dimension of
    `batch_data["input_ids"]` / `batch_data["labels"]` when `batch_data` is given, otherwise from the
    config's `n_positions` (1024 when the config does not define it).

    Args:
        megatron_lm_plugin: plugin whose encoder/decoder sequence lengths may be updated and whose
            `megatron_lm_default_args` dict is written in place.
        model: model exposing a T5 style `config` (`num_layers`, `d_model`, `num_heads`, `vocab_size`,
            `return_dict`, optionally `n_positions`).
        batch_data: optional mapping; `["input_ids"].shape[1]` and `["labels"].shape[1]` are read.
    """
    config = model.config
    layer_count = config.num_layers
    model_dim = config.d_model
    head_count = config.num_heads
    position_limit = getattr(config, "n_positions", 1024)
    vocab = config.vocab_size

    if megatron_lm_plugin.encoder_seq_length is None:
        megatron_lm_plugin.encoder_seq_length = (
            position_limit if batch_data is None else batch_data["input_ids"].shape[1]
        )
    if megatron_lm_plugin.decoder_seq_length is None:
        megatron_lm_plugin.decoder_seq_length = (
            position_limit if batch_data is None else batch_data["labels"].shape[1]
        )

    defaults = megatron_lm_plugin.megatron_lm_default_args
    defaults["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length
    defaults["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length
    defaults["model_type_name"] = "t5"
    defaults["num_layers"] = layer_count
    defaults["hidden_size"] = model_dim
    defaults["num_attention_heads"] = head_count
    defaults["max_position_embeddings"] = position_limit
    defaults["pretraining_flag"] = True
    defaults["orig_vocab_size"] = vocab
    defaults["model_return_dict"] = config.return_dict
def parse_llama_config(megatron_lm_plugin, model, batch_data):
    """
    Populate `megatron_lm_plugin.megatron_lm_default_args` from a Llama Hugging Face model config.

    When `megatron_lm_plugin.seq_length` is unset it is taken from the first available of:
    `config.max_sequence_length`, `megatron_lm_plugin.decoder_seq_length`,
    `batch_data["input_ids"].shape[1]` (when `batch_data` is given) and `config.max_position_embeddings`.
    An already-set `seq_length` is left untouched.

    Args:
        megatron_lm_plugin: plugin whose `seq_length` may be updated and whose `megatron_lm_default_args`
            dict is written in place.
        model: model exposing a Llama style `config`.
        batch_data: optional mapping; only `batch_data["input_ids"].shape[1]` is read.
    """
    config = model.config
    layer_count = config.num_hidden_layers
    width = config.hidden_size
    head_count = config.num_attention_heads
    vocab = config.vocab_size
    position_limit = config.max_position_embeddings
    configured_length = getattr(config, "max_sequence_length", None)

    if megatron_lm_plugin.seq_length is None:
        if configured_length is not None:
            resolved = configured_length
        elif megatron_lm_plugin.decoder_seq_length is not None:
            resolved = megatron_lm_plugin.decoder_seq_length
        elif batch_data is not None:
            resolved = batch_data["input_ids"].shape[1]
        else:
            resolved = position_limit
        megatron_lm_plugin.seq_length = resolved

    defaults = megatron_lm_plugin.megatron_lm_default_args
    defaults["return_logits"] = megatron_lm_plugin.return_logits
    defaults["tokenizer_type"] = "Llama2Tokenizer"
    defaults["model_type_name"] = "gpt"
    defaults["num_layers"] = layer_count
    defaults["pretraining_flag"] = True
    defaults["hidden_size"] = width
    defaults["num_attention_heads"] = head_count
    defaults["orig_vocab_size"] = vocab
    defaults["max_position_embeddings"] = position_limit
    defaults["seq_length"] = megatron_lm_plugin.seq_length
    defaults["model_return_dict"] = config.return_dict
def parse_glm4_moe_config(megatron_lm_plugin, model, batch_data):
    """
    Populate `megatron_lm_plugin.megatron_lm_default_args` from a GLM-4 MoE Hugging Face model config.

    When `megatron_lm_plugin.seq_length` is unset it is taken from the first available of:
    `config.max_sequence_length`, `megatron_lm_plugin.decoder_seq_length`,
    `batch_data["input_ids"].shape[1]` (when `batch_data` is given) and `config.max_position_embeddings`.

    The optional fp8 settings (`fp8`, `fp8_param`, `fp8_param_gather`, `fp8_recipe`) are forwarded only when
    the config has a truthy `fp8_param`; `bf16` is forwarded only when the config defines it.

    Args:
        megatron_lm_plugin: plugin whose `seq_length` may be updated and whose `megatron_lm_default_args`
            dict is written in place.
        model: model exposing a GLM-4 MoE style `config`.
        batch_data: optional mapping; only `batch_data["input_ids"].shape[1]` is read.
    """
    config = model.config
    args = megatron_lm_plugin.megatron_lm_default_args
    num_layers = config.num_hidden_layers
    max_position_embeddings = config.max_position_embeddings
    config_seq_length = getattr(config, "max_sequence_length", None)

    if megatron_lm_plugin.seq_length is None:
        if config_seq_length is not None:
            megatron_lm_plugin.seq_length = config_seq_length
        elif megatron_lm_plugin.decoder_seq_length is not None:
            megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
        elif batch_data is not None:
            megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
        else:
            megatron_lm_plugin.seq_length = max_position_embeddings

    # Model shape and tokenizer.
    args["return_logits"] = megatron_lm_plugin.return_logits
    args["tokenizer_type"] = "HuggingFaceTokenizer"
    args["model_type_name"] = "gpt"
    args["num_layers"] = num_layers
    args["pretraining_flag"] = False
    args["hidden_size"] = config.hidden_size
    args["num_attention_heads"] = config.num_attention_heads
    args["kv_channels"] = config.head_dim
    args["orig_vocab_size"] = config.vocab_size
    args["max_position_embeddings"] = max_position_embeddings
    args["seq_length"] = megatron_lm_plugin.seq_length
    args["model_return_dict"] = config.return_dict

    # Attention / block layout.
    args["position_embedding_type"] = "rope"
    args["original_model_type"] = config.model_type
    args["qk_layernorm"] = config.use_qk_norm  # this is true for glm4.5 but False for glm4.5-air.
    args["add_bias_linear"] = False
    args["group_query_attention"] = True
    args["num_query_groups"] = config.num_key_value_heads
    args["ffn_hidden_size"] = config.intermediate_size
    args["add_qkv_bias"] = True
    args["normalization"] = "RMSNorm"
    # NOTE(review): this hyphenated key looks like a typo of "rotary_percent", which is set from
    # `config.partial_rotary_factor` further down; kept as-is so the emitted dict does not change.
    args["rotary-percent"] = 0.5
    args["swiglu"] = True

    # Mixture-of-experts routing.
    args["moe_ffn_hidden_size"] = config.moe_intermediate_size
    args["moe_shared_expert_intermediate_size"] = config.moe_intermediate_size
    args["moe_router_pre_softmax"] = True
    args["moe_router_score_function"] = "sigmoid"
    args["moe_router_enable_expert_bias"] = True
    args["moe_router_bias_update_rate"] = 0
    args["moe_router_load_balancing_type"] = "seq_aux_loss"
    args["moe_token_dispatcher_type"] = "alltoall"
    args["moe_router_topk"] = config.num_experts_per_tok
    args["moe_router_topk_scaling_factor"] = config.routed_scaling_factor
    # `first_k_dense_replace` zeros followed by ones for the remaining layers
    # (presumably 0 = dense layer, 1 = MoE layer in Megatron's `moe_layer_freq` list).
    dense_layers = config.first_k_dense_replace
    args["moe_layer_freq"] = [0] * dense_layers + [1] * (num_layers - dense_layers)
    args["num_experts"] = config.n_routed_experts
    args["moe_grouped_gemm"] = True
    args["moe_router_dtype"] = "fp32"
    args["moe_permute_fusion"] = True
    args["moe_aux_loss_coeff"] = 0

    # Rotary embedding, normalization and misc.
    args["rotary_base"] = config.rope_theta
    args["rope_type"] = "rope"
    args["rotary_percent"] = config.partial_rotary_factor
    args["norm_epsilon"] = 1e-3
    args["use_flash_attn"] = True
    args["eos_token_id"] = config.eos_token_id

    # Optional mixed-precision settings: these attributes are read defensively because
    # the config may not define them.
    if getattr(config, "fp8_param", False):
        args["fp8"] = config.fp8
        args["fp8_param"] = config.fp8_param
        args["fp8_param_gather"] = config.fp8_param_gather
        args["fp8_recipe"] = config.fp8_recipe
    bf16 = getattr(config, "bf16", None)
    if bf16 is not None:
        args["bf16"] = bf16

    args["untie_embeddings_and_output_weights"] = not config.tie_word_embeddings
    logger.info(f"Parsed GLM4 MoE config: {megatron_lm_plugin.megatron_lm_default_args}")
@dataclass
class BnbQuantizationConfig:
    """
    A plugin to enable BitsAndBytes 4bit and 8bit quantization

    Exactly one of `load_in_8bit` / `load_in_4bit` must be `True`. After validation in `__post_init__`, string
    dtype aliases are converted to `torch.dtype` values and a `target_dtype` attribute is set
    (`CustomDtype.INT4` for 4bit, `torch.int8` for 8bit).

    Args:
        load_in_8bit (`bool`, defaults to `False`):
            Enable 8bit quantization.
        llm_int8_threshold (`float`, defaults to `6.0`):
            Value of the outlier threshold. Only relevant when `load_in_8bit=True`; a warning is emitted if it is
            changed while `load_in_4bit=True`.
        load_in_4bit (`bool`, defaults to `False`):
            Enable 4bit quantization.
        bnb_4bit_quant_type (`str`, defaults to `fp4`):
            Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}.
        bnb_4bit_use_double_quant (`bool`, defaults to `False`):
            Enable nested quantization where the quantization constants from the first quantization are quantized
            again.
        bnb_4bit_compute_dtype (`str` or `torch.dtype`, defaults to `fp16`):
            This sets the computational type which might be different than the input type. For example, inputs might
            be fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'} or a
            `torch.dtype`.
        torch_dtype (`torch.dtype` or `str`, defaults to `None`):
            This sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value
            to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model. When left
            as `None` it becomes `torch.float16` (8bit) or `bnb_4bit_compute_dtype` (4bit). The strings
            'fp32', 'fp16' and 'bf16' are also accepted.
        skip_modules (`List[str]`, defaults to `None`):
            An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`.
        keep_in_fp32_modules (`List`, defaults to `None`):
            An explicit list of the modules that we don't quantize. We keep them in `torch.float32`.
    """

    load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})

    llm_int8_threshold: float = field(
        default=6.0,
        metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"},
    )

    load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})

    bnb_4bit_quant_type: str = field(
        default="fp4",
        metadata={
            "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}."
        },
    )

    bnb_4bit_use_double_quant: bool = field(
        default=False,
        metadata={
            "help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
        },
    )

    bnb_4bit_compute_dtype: str = field(
        default="fp16",
        metadata={
            "help": "This sets the computational type which might be different than the input time. For example, inputs might be "
            "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
        },
    )

    torch_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={
            # The trailing space on the first fragment is required: the two literals are implicitly concatenated.
            "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value "
            "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
        },
    )

    skip_modules: Optional[list[str]] = field(
        default=None,
        metadata={
            "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
        },
    )

    keep_in_fp32_modules: Optional[list[str]] = field(
        default=None,
        metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
    )

    def __post_init__(self):
        """
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.

        Raises:
            ValueError: if an argument has the wrong type or an unsupported value, or if neither / both of
                `load_in_8bit` and `load_in_4bit` are enabled.
        """
        if not isinstance(self.load_in_8bit, bool):
            raise ValueError("load_in_8bit must be a boolean")

        if not isinstance(self.load_in_4bit, bool):
            raise ValueError("load_in_4bit must be a boolean")

        if self.load_in_4bit and self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both True")

        if not self.load_in_4bit and not self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both False")

        if not isinstance(self.llm_int8_threshold, (int, float)):
            raise ValueError("llm_int8_threshold must be a float or an int")

        if not isinstance(self.bnb_4bit_quant_type, str):
            raise ValueError("bnb_4bit_quant_type must be a string")
        elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]:
            raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}")

        if not isinstance(self.bnb_4bit_use_double_quant, bool):
            raise ValueError("bnb_4bit_use_double_quant must be a boolean")

        if isinstance(self.bnb_4bit_compute_dtype, str):
            self.bnb_4bit_compute_dtype = self._dtype_from_alias("bnb_4bit_compute_dtype", self.bnb_4bit_compute_dtype)
        elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

        if self.skip_modules is not None and not isinstance(self.skip_modules, list):
            raise ValueError("skip_modules must be a list of strings")

        if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
            raise ValueError("keep_in_fp32_modules must be a list of strings")

        # Exactly one of the two flags is True at this point (validated above).
        self.target_dtype = CustomDtype.INT4 if self.load_in_4bit else torch.int8

        if self.load_in_4bit and self.llm_int8_threshold != 6.0:
            warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")

        if isinstance(self.torch_dtype, str):
            self.torch_dtype = self._dtype_from_alias("torch_dtype", self.torch_dtype)

        if self.torch_dtype is None:
            # 8bit defaults to fp16; 4bit follows the compute dtype.
            self.torch_dtype = torch.float16 if self.load_in_8bit else self.bnb_4bit_compute_dtype

        if not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError("torch_dtype must be a torch.dtype")

    @staticmethod
    def _dtype_from_alias(arg_name, alias):
        """Map a short dtype alias ('fp32', 'fp16', 'bf16') to its `torch.dtype`; raise `ValueError` otherwise."""
        aliases = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
        if alias not in aliases:
            raise ValueError(f"{arg_name} must be in ['fp32','fp16','bf16'] but found {alias}")
        return aliases[alias]
def get_module_class_from_name(module, name):
    """
    Gets a class from a module by its name.

    The search is depth-first and pre-order: `module` itself is checked first, then each child (in
    `module.children()` order) together with its entire subtree before the next child is visited.

    Args:
        module (`torch.nn.Module`): The module to get the class from.
        name (`str`): The name of the class.

    Returns:
        The class of the first module whose class name equals `name`, or `None` if nothing matches.
    """
    pending = [module]
    while pending:
        current = pending.pop()
        if current.__class__.__name__ == name:
            return current.__class__
        # Push children reversed so the first child is popped next, reproducing recursive pre-order.
        pending.extend(reversed(list(current.children())))
    return None
Xet Storage Details
- Size:
- 154 kB
- Xet hash:
- f620a6e627dcbdd60a372e8b26fdbf132c96cdfcf6585573ade6b8a2879e0af7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.