| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Accelerate utilities: Utilities related to accelerate |
| """ |
|
|
| from packaging import version |
|
|
| from .import_utils import is_accelerate_available |
|
|
|
|
| if is_accelerate_available(): |
| import accelerate |
|
|
|
|
| def apply_forward_hook(method): |
| """ |
| Decorator that applies a registered CpuOffload hook to an arbitrary function rather than `forward`. This is useful |
| for cases where a PyTorch module provides functions other than `forward` that should trigger a move to the |
| appropriate acceleration device. This is the case for `encode` and `decode` in [`AutoencoderKL`]. |
| |
| This decorator looks inside the internal `_hf_hook` property to find a registered offload hook. |
| |
| :param method: The method to decorate. This method should be a method of a PyTorch module. |
| """ |
| if not is_accelerate_available(): |
| return method |
| accelerate_version = version.parse(accelerate.__version__).base_version |
| if version.parse(accelerate_version) < version.parse("0.17.0"): |
| return method |
|
|
| def wrapper(self, *args, **kwargs): |
| if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): |
| self._hf_hook.pre_forward(self) |
| return method(self, *args, **kwargs) |
|
|
| return wrapper |
|
|