| | import torch |
| | from torch.func import functional_call |
| | import queue |
| | import threading |
| | from typing import Dict, List, Any |
| | import omegaconf |
| | from pydantic import BaseModel, validator |
| | from typing import Optional |
| | from functools import wraps |
| |
|
| | def _callable_once(func): |
| | @wraps(func) |
| | def wrapper(self, *args, **kwargs): |
| | method_called_flag = f"_called_once_{func.__name__}" |
| | if getattr(self, method_called_flag, False): |
| | raise RuntimeError(f"{func.__name__} can only be called once.") |
| | setattr(self, method_called_flag, True) |
| | return func(self, *args, **kwargs) |
| | return wrapper |
| |
|
class OffloadCleanCacheWrapperParam(BaseModel):
    """Settings for wrapping one method of a module with a cache-cleaning hook."""

    # Object that owns the method to wrap (resolved from the config path).
    module: Any
    # Name of the method on ``module`` that gets wrapped.
    method_name: str
    # Threshold, in GB, forwarded as ``diff_mem_gb_thre`` to the wrapper.
    diff_mem_gb_thre: float
| |
|
class OffloadParam(BaseModel):
    """Validated bundle of offload settings.

    The three ``*_param_dict`` helpers slice the settings into keyword
    dictionaries; ``show`` prints every field for inspection.
    """

    offload_module: Any
    cpu_mem_gb: float
    pre_copy_step: Optional[int] = None
    clean_cache_after_forward: Optional[bool] = None
    dtype: Optional[str] = None
    offload_layer_dict: Dict[str, int] = {}
    ignore_layer_list: List[str] = []
    clean_cache_wrapper: Optional[OffloadCleanCacheWrapperParam] = None
    debug: Optional[bool] = None

    @validator('dtype')
    def parse_dtype(cls, value):
        """Translate a dtype name such as ``'torch.float16'`` into a ``torch.dtype``.

        ``None`` is passed through unchanged; an unknown name raises ``ValueError``.
        """
        if value is None:
            return None
        supported = {
            'torch.float16': torch.float16,
            'torch.float32': torch.float32,
            'torch.float64': torch.float64,
            'torch.int64': torch.int64,
        }
        resolved = supported.get(value)
        if resolved is None:
            raise ValueError(f"Unsupported dtype: {value}")
        return resolved

    def init_param_dict(self):
        """Return ``cpu_mem_gb`` plus every optional setting that is not ``None``."""
        params = {'cpu_mem_gb': self.cpu_mem_gb}
        for key in ('pre_copy_step', 'clean_cache_after_forward', 'debug'):
            setting = getattr(self, key)
            if setting is not None:
                params[key] = setting
        return params

    def offload_layer_param_dict(self):
        """Return the module, layer-depth map, ignore list and dtype as a dict."""
        return {
            'module': self.offload_module,
            'offload_layer_dict': self.offload_layer_dict,
            'ignore_layer_list': self.ignore_layer_list,
            'dtype': self.dtype,
        }

    def clean_cache_param_dict(self):
        """Return the cache-wrapper settings, or an empty dict when none are set."""
        wrapper_cfg = self.clean_cache_wrapper
        if wrapper_cfg is None:
            return {}
        return {
            'module': wrapper_cfg.module,
            'method_name': wrapper_cfg.method_name,
            'diff_mem_gb_thre': wrapper_cfg.diff_mem_gb_thre,
        }

    @staticmethod
    def recursive_print(model, indent=0):
        """Print each field of ``model``, descending into nested pydantic models.

        Values that are ``torch.nn.Module`` instances are reported by class
        name only; their value is not printed.
        """
        outer_pad = " " * indent
        inner_pad = " " * (indent + 2)
        for field_name in model.__fields__:
            field_value = getattr(model, field_name)
            print(outer_pad + f"{field_name}:")

            if isinstance(field_value, BaseModel):
                print(inner_pad + f"--- Nested model: {field_value.__class__.__name__}")
                OffloadParam.recursive_print(field_value, indent + 4)
                continue

            print(inner_pad + f"class: {field_value.__class__.__name__}")
            if not isinstance(field_value, torch.nn.Module):
                print(inner_pad + f"value: {field_value}")

    def show(self):
        """Print all settings between two separator lines."""
        rule = "-" * 20
        print(rule + "[OffloadParam]" + rule)
        OffloadParam.recursive_print(self)
        print("-" * 40)
| |
|
| |
|
class OffloadParamParse:
    """Turn an OmegaConf offload section into an ``OffloadParam``."""

    def __init__(self):
        pass

    @staticmethod
    def _get_model(root_model: torch.nn.Module, model_dir: str):
        """Resolve a dotted path such as ``self.a.b`` to an attribute of ``root_model``.

        Every path component equal to ``"self"`` is skipped; each remaining
        component is looked up with ``getattr`` on the object reached so far.

        Raises:
            AssertionError: if the path does not start with ``self`` or a
                component is not an attribute of the current object.
        """
        assert(model_dir.startswith("self")), f"model_dir {model_dir} must startswith `self`"
        model = root_model
        for layer in model_dir.split('.'):
            if layer == "self":
                continue
            assert(hasattr(model, layer)), f"model not has layer [{layer}]!"
            model = getattr(model, layer)
        return model

    @staticmethod
    def _optional(cfg, key, default=None):
        """Return ``cfg.<key>`` when the config has that key, else ``default``."""
        return getattr(cfg, key) if hasattr(cfg, key) else default

    @staticmethod
    def parse_config(root_model: torch.nn.Module, cfg: omegaconf.DictConfig)->OffloadParam:
        """Build an ``OffloadParam`` from ``cfg``.

        ``offload_module``, ``cpu_mem_gb`` and ``dtype`` are required keys; all
        others are optional. Module paths in the config are resolved against
        ``root_model`` with ``_get_model``.

        Raises:
            AssertionError: if a required key is missing or a module path is invalid.
        """
        required_keys = ("offload_module", "cpu_mem_gb", "dtype")
        missing_keys = [key for key in required_keys if not hasattr(cfg, key)]
        assert not missing_keys, f"offload config is missing required keys: {missing_keys}"

        offload_module = OffloadParamParse._get_model(root_model, cfg.offload_module)

        # OmegaConf containers are converted to plain dict/list before they
        # reach pydantic: a ListConfig is not a ``list`` subclass, so handing
        # it to a ``List[str]`` field can fail validation.
        offload_layer_dict = dict(cfg.offload_layer_dict.items()) \
            if hasattr(cfg, "offload_layer_dict") else {}
        ignore_layer_list = list(cfg.ignore_layer_list) \
            if hasattr(cfg, "ignore_layer_list") else []

        clean_cache_wrapper = None
        if hasattr(cfg, "clean_cache_wrapper"):
            clean_cache_cfg = cfg.clean_cache_wrapper
            clean_cache_wrapper = OffloadCleanCacheWrapperParam(
                module=OffloadParamParse._get_model(root_model, clean_cache_cfg.module),
                method_name=clean_cache_cfg.method_name,
                diff_mem_gb_thre=clean_cache_cfg.diff_mem_gb_thre)

        return OffloadParam(
            offload_module=offload_module,
            cpu_mem_gb=cfg.cpu_mem_gb,
            pre_copy_step=OffloadParamParse._optional(cfg, "pre_copy_step"),
            clean_cache_after_forward=OffloadParamParse._optional(cfg, "clean_cache_after_forward"),
            dtype=cfg.dtype,
            offload_layer_dict=offload_layer_dict,
            ignore_layer_list=ignore_layer_list,
            clean_cache_wrapper=clean_cache_wrapper,
            debug=OffloadParamParse._optional(cfg, "debug")
        )
| |
|
| |
|
class LayerParamStruct:
    """Bookkeeping record for one layer's staged weights.

    ``count`` starts at zero and ``device_state`` starts empty; both are
    filled in by whoever stages the layer.
    """

    def __init__(self):
        # No staged copy yet.
        self.device_state = None
        # Number of staged copies not yet consumed.
        self.count = 0
| |
|
| |
|
class OffloadProfiler:
    """Keep selected sub-module parameters in host memory and stage them on the GPU per forward call.

    ``offload_layer`` walks a model, moves it to ``cuda:<device_index>`` and
    replaces ``forward`` on the chosen children with a wrapper. The wrapper
    asks a background thread to copy the child's ``state_dict`` to the GPU,
    waits for that copy, and runs the child through ``functional_call`` with
    the staged tensors under ``torch.no_grad()``.

    The first time a wrapped child runs, it is appended to
    ``execution_order``; on later runs that recorded order is used to request
    copies for up to ``pre_copy_step`` upcoming layers ahead of time.
    """

    def __init__(self, device_index=0, cpu_mem_gb=-1, pre_copy_step=1, clean_cache_after_forward=False, debug=False):
        """
        Args:
            device_index: index of the CUDA device that staged weights are copied to.
            cpu_mem_gb: budget, in GB, of parameter data moved to host memory;
                a negative value disables the budget check.
            pre_copy_step: how many upcoming layers to request copies for in advance.
            clean_cache_after_forward: when true, the CUDA cache is trimmed
                after a wrapped forward once reserved memory passes the
                current water line.
            debug: print diagnostic messages.
        """
        self.clean_cache_after_forward = clean_cache_after_forward
        self.cpu_mem_gb = cpu_mem_gb
        self.cpu_mem_b_count = 0
        self.device_index = device_index
        self.execution_order = []
        self.execution_order_idx = {}
        self.debug = debug

        # Probe whether pinned host memory is actually available.
        self.pin_memory = False
        test_data = torch.rand(1, 1, device='cpu')
        pin_data = test_data.pin_memory()
        self.pin_memory = pin_data.is_pinned()
        print(f"pin:{self.pin_memory}")

        self.copy_stream = torch.cuda.Stream()
        self.copy_queue = queue.Queue()
        self.layer_param: Dict[str, LayerParamStruct] = {}
        self.model_map = {}
        self.stop_flag = False
        self.copy_condition = threading.Condition()
        self.queue_condition = threading.Condition()
        self.mem_line_b = 0

        self.cur_copy_idx = 0
        self.execute_over = False
        self.pre_copy_step = pre_copy_step

        # Ring of the most recently staged state dicts.
        self.tmp_state_list = [None] * (pre_copy_step + 2)
        self.tmp_state_idx = 0

        # Start the worker last so every attribute it reads already exists.
        self.copy_thread = threading.Thread(target=self._copy_thread_fun)
        self.copy_thread.daemon = True
        self.copy_thread.start()

    def _cuda_device(self):
        """Return the ``torch.device`` that staged weights are copied to."""
        return torch.device(f"cuda:{self.device_index}")

    def stop(self):
        """Stop the copy thread, wait for it to exit and drop the staged state."""
        with self.queue_condition:
            self.stop_flag = True
            self.queue_condition.notify()
        self.copy_thread.join()

        del self.layer_param
        del self.model_map
        del self.copy_stream

    def _copy_thread_fun(self):
        """Worker loop: take layer names off the queue and stage their weights on the GPU."""
        while not self.stop_flag:
            with self.queue_condition:
                while self.copy_queue.empty() and not self.stop_flag:
                    self.queue_condition.wait()
                if self.stop_flag:
                    break
                layer_name = self.copy_queue.get()

            with torch.cuda.stream(self.copy_stream):
                if layer_name not in self.model_map:
                    print(f"get model error! {layer_name}")
                    continue

                model = self.model_map[layer_name]
                device = self._cuda_device()
                device_state = {
                    k: v.to(device, non_blocking=False)
                    for k, v in model.state_dict().items()
                }
                self.copy_stream.synchronize()

                # Keep the copy referenced from the ring until its slot is reused.
                self.tmp_state_list[self.tmp_state_idx] = device_state
                self.tmp_state_idx = (self.tmp_state_idx + 1) % len(self.tmp_state_list)

                with self.copy_condition:
                    entry = self.layer_param.get(layer_name)
                    if entry is None:
                        entry = LayerParamStruct()
                        self.layer_param[layer_name] = entry
                    entry.count += 1
                    entry.device_state = device_state
                    self.copy_condition.notify()
        print("copy thread stop..")

    def _get_new_step_copy_begin_end(self, tag_name):
        """Return the half-open range of execution-order indices to request copies for.

        The range starts at ``cur_copy_idx`` and ends ``pre_copy_step`` layers
        past ``tag_name`` (with ``pre_copy_step`` capped at half the recorded
        order length). When that range is negative or longer than
        ``pre_copy_step + 1``, ``cur_copy_idx`` is reset to the index of
        ``tag_name`` and the range is recomputed from there. The end index may
        exceed the order length; callers reduce indices modulo that length.
        """
        order_len = len(self.execution_order)
        pre_copy_step = min(self.pre_copy_step, order_len // 2)

        cur_exe_idx = self.execution_order_idx[tag_name]
        copy_begin = self.cur_copy_idx
        copy_end = cur_exe_idx + pre_copy_step + 1
        if copy_end - copy_begin > order_len:
            copy_end %= order_len

        span = copy_end - copy_begin
        if span > pre_copy_step + 1 or span < 0:
            self.cur_copy_idx = cur_exe_idx
            return self._get_new_step_copy_begin_end(tag_name=tag_name)
        return copy_begin, copy_end

    def _enqueue_copy(self, tag_name):
        """Queue one layer for the copy thread and wake it."""
        self.copy_queue.put(tag_name)
        with self.queue_condition:
            self.queue_condition.notify()

    def _schedule_copies(self, tag_name, module):
        """Request the copies needed before ``tag_name`` runs.

        A layer seen for the first time is recorded in the execution order and
        queued on its own. A layer already in the order triggers copies for
        the range returned by ``_get_new_step_copy_begin_end``.
        """
        if tag_name not in self.execution_order_idx:
            self.model_map[tag_name] = module
            self.execution_order.append(tag_name)
            self.execution_order_idx[tag_name] = len(self.execution_order) - 1
            self._enqueue_copy(tag_name)
            return

        copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
        order_len = len(self.execution_order)
        for idx in range(copy_begin, copy_end):
            self._enqueue_copy(self.execution_order[idx % order_len])
        self.cur_copy_idx = copy_end % order_len

    def _trim_reserved_memory(self, diff_mem_b_thre):
        """Empty the CUDA cache when reserved memory is above the water line.

        After emptying, the water line ``mem_line_b`` is moved: to the
        midpoint between the old and new reserved sizes when more than
        ``diff_mem_b_thre`` bytes were released, otherwise to just above the
        old reserved size.

        Returns:
            ``(reserved_before, reserved_after)`` in bytes when the cache was
            emptied, otherwise ``None``.
        """
        reserved = torch.cuda.memory_reserved()
        if reserved <= self.mem_line_b:
            return None
        torch.cuda.empty_cache()
        cur_reserved = torch.cuda.memory_reserved()
        released = reserved - cur_reserved
        if released > diff_mem_b_thre:
            self.mem_line_b = cur_reserved + released / 2 + 10
        else:
            self.mem_line_b = reserved + 10
        return reserved, cur_reserved

    def make_forward_wrapper(self, module, tag_name, ignore_layer_list=None):
        """Move ``module``'s parameters to host memory and wrap its ``forward``.

        Parameters whose qualified name starts with an entry of
        ``ignore_layer_list`` are left where they are. Offloading stops once
        the ``cpu_mem_gb`` budget is used up. ``forward`` is only wrapped when
        at least one parameter was moved.

        Args:
            module: the module to offload; modified in place.
            tag_name: dotted name identifying the module in the execution order.
            ignore_layer_list: name prefixes of parameters to keep on the GPU.

        Returns:
            ``module``.
        """
        ignore_prefixes = tuple(ignore_layer_list) if ignore_layer_list else ()
        original_forward = module.forward

        layer_param_size = 0
        for name, param in module.named_parameters():
            layer_param_size += param.data.numel() * param.data.element_size() / 1024 / 1024

        taget_cpu_mem_b = self.cpu_mem_gb * 1024 * 1024 * 1024
        offload = False
        for name, param in module.named_parameters():
            p_name = f"{tag_name}.{name}" if tag_name else name
            if p_name.startswith(ignore_prefixes):
                if self.debug:
                    print(f"ignore layer param: {p_name}")
                # Skip this parameter entirely so it stays on the GPU.
                continue

            if taget_cpu_mem_b >= 0 and self.cpu_mem_b_count >= taget_cpu_mem_b:
                break
            cpu_data = torch.empty_strided(size=param.data.size(),
                                           stride=param.data.stride(),
                                           dtype=param.data.dtype,
                                           layout=param.data.layout,
                                           device='cpu',
                                           pin_memory=self.pin_memory)
            cpu_data.copy_(param.data)
            param.data = cpu_data

            self.cpu_mem_b_count += param.data.numel() * param.data.element_size()
            offload = True
        if self.debug:
            print(f"layer: {tag_name}, type: {module.__class__.__name__}, size(MB): {layer_param_size}, offload: {offload}, sum_offload_size(MB): {self.cpu_mem_b_count/1024/1024}")

        if offload:
            def forward_wrapper(*args, **kwargs):
                # The real forward must be installed while functional_call
                # runs, otherwise calling the module would re-enter this wrapper.
                module.forward = original_forward
                try:
                    self._schedule_copies(tag_name, module)

                    with self.copy_condition:
                        while tag_name not in self.layer_param:
                            self.copy_condition.wait()
                        entry = self.layer_param[tag_name]
                        run_state = entry.device_state
                        entry.count -= 1

                    try:
                        module.eval()
                        with torch.no_grad():
                            output = functional_call(module, run_state, args=args, kwargs=kwargs)
                    finally:
                        # Drop the staged copy once nothing else is waiting on it,
                        # even when the forward call failed.
                        with self.copy_condition:
                            entry = self.layer_param.get(tag_name)
                            if entry is not None and entry.count == 0:
                                del self.layer_param[tag_name]

                    if self.clean_cache_after_forward:
                        trimmed = self._trim_reserved_memory(1 * (1024 ** 3))
                        if trimmed is not None and self.debug:
                            reserved, cur_reserved = trimmed
                            print(f"child mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}, child name: {tag_name}")
                    return output
                finally:
                    # Re-install the wrapper on success and on failure alike.
                    module.forward = forward_wrapper

            module.forward = forward_wrapper

        torch.cuda.empty_cache()
        return module

    def reset_empty_cache_mem_line(self):
        """Reset the cache water line to zero and empty the CUDA cache."""
        self.mem_line_b = 0
        torch.cuda.empty_cache()

    def clean_cache_wrapper(self, module, method_name='', diff_mem_gb_thre=1):
        """Wrap ``module.<method_name>`` so the CUDA cache is trimmed after each call.

        Args:
            module: object owning the method.
            method_name: name of the method to wrap. When the attribute is
                missing or not callable, a message is printed and nothing is wrapped.
            diff_mem_gb_thre: threshold in GB passed (as bytes) to the
                water-line update.

        Returns:
            ``module``.
        """
        original_fun = getattr(module, method_name, None)
        if not callable(original_fun):
            print(f"no this method {method_name}")
            return module

        diff_mem_b_thre = diff_mem_gb_thre * (1024 ** 3)
        self.reset_empty_cache_mem_line()

        def clean_wrapper(*args, **kwargs):
            # The original is installed during the call so nested calls to the
            # same method do not trigger another cache trim.
            setattr(module, method_name, original_fun)
            try:
                output = original_fun(*args, **kwargs)
                trimmed = self._trim_reserved_memory(diff_mem_b_thre)
                if trimmed is not None and self.debug:
                    reserved, cur_reserved = trimmed
                    print(f"mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}")
                return output
            finally:
                # Re-install the wrapper on success and on failure alike.
                setattr(module, method_name, clean_wrapper)

        setattr(module, method_name, clean_wrapper)
        return module

    @_callable_once
    def offload_layer(self, module, offload_layer_dict=None, ignore_layer_list=None, dtype:torch.dtype = None):
        """Offload ``module`` starting from the root tag; may be called once per profiler.

        See ``_offload_layer`` for the meaning of the arguments.

        Returns:
            ``module``.

        Raises:
            RuntimeError: on a second call on the same profiler instance.
        """
        return self._offload_layer(
            module=module,
            tag="",
            offload_layer_dict=offload_layer_dict,
            ignore_layer_list=ignore_layer_list,
            dtype=dtype
        )

    def _move_tensor(self, tensor, device, dtype):
        """Move ``tensor`` to ``device`` in place and optionally cast it.

        Only floating-point and complex tensors are cast, matching what
        ``torch.nn.Module.to(dtype)`` does for the child modules, so integer
        tensors keep their dtype.
        """
        tensor.data = tensor.data.to(device)
        if dtype is not None and (tensor.is_floating_point() or tensor.is_complex()):
            tensor.data = tensor.data.to(dtype)

    def _offload_layer(self, module, tag="", offload_layer_dict=None, ignore_layer_list=None, dtype:torch.dtype = None):
        """
        Offload specific layers of a PyTorch model to a specified depth.
        A model can only be offloaded once.

        Args:
            module (torch.nn.Module):
                The PyTorch model containing the layers to offload. This is the model that will be modified in place.

            tag (str, optional):
                A string identifier for the model.
                Default is an empty string.

            offload_layer_dict (dict, optional):
                A dictionary where keys are layer names and values represent the depth at which the offloading should occur.
                For example,
                ```offload_layer_dict = {'cfm_wrapper': 5, 'hubert': 4}``` means that the `cfm_wrapper` layer should
                be offloaded at depth 5, and the `hubert` layer should be offloaded at depth 4.
                Default is an empty dictionary.

            ignore_layer_list (list, optional):
                A list of layer names or parameter identifiers to be ignored during the offloading process.
                Layers in this list will not be offloaded, even if they are present in the `offload_layer_dict`.
                For example,
                ```ignore_layer_list = ['cfm_wrapper.estimator.h', 'cfm_wrapper.estimator.adaln_single']```
                means that layers starting with `cfm_wrapper.estimator.h` or `cfm_wrapper.estimator.adaln_single`
                will not be offloaded.
                Default is an empty list.

            dtype (torch.dtype, optional):
                The data type (e.g., `torch.float16`, `torch.float32`) to which floating-point tensors are converted.
                If `None`, the data type of the layers will remain unchanged. Default is `None`.

        Returns:
            torch.nn.Module: the same `module` object, modified in place.
        """
        offload_layer_dict = {} if offload_layer_dict is None else offload_layer_dict
        ignore_layer_list = [] if ignore_layer_list is None else ignore_layer_list
        ignore_prefixes = tuple(ignore_layer_list)
        device = self._cuda_device()

        # Tensors owned directly by this module: parameters, buffers and
        # public plain-tensor attributes.
        for p in module._parameters.values():
            if p is not None:
                self._move_tensor(p, device, dtype)
        for b in module._buffers.values():
            if b is not None:
                self._move_tensor(b, device, dtype)
        for attr_name, attr in module.__dict__.items():
            if isinstance(attr, torch.Tensor) and not attr_name.startswith('_'):
                self._move_tensor(attr, device, dtype)

        # Snapshot the children because each one is re-assigned inside the loop.
        for name, child in list(module.named_children()):
            current_tag = f"{tag}.{name}" if tag else name
            child = child.to(device)
            if dtype is not None:
                child = child.to(dtype)

            torch.cuda.empty_cache()
            setattr(module, name, child)

            pre_name = current_tag.split('.')[0]
            if pre_name not in offload_layer_dict:
                if self.debug:
                    param_size = sum(p.data.numel() * p.data.element_size() for p in child.parameters())
                    param_size = param_size / 1024 / 1024
                    print(f"not offload layer {current_tag}, size: {param_size}MB")
                continue

            # Descend further only while the configured depth has not been
            # reached and the child actually has sub-modules.
            layer_count = current_tag.count('.') + 1
            layer_deep = offload_layer_dict[pre_name]
            has_children = layer_count < layer_deep and any(child.named_children())
            if has_children:
                self._offload_layer(module=child,
                                    tag=current_tag,
                                    offload_layer_dict=offload_layer_dict,
                                    ignore_layer_list=ignore_layer_list,
                                    dtype=dtype)
                continue

            ignore = current_tag.startswith(ignore_prefixes)
            if ignore and self.debug:
                print(f"ignore layer offload: {current_tag}")

            if hasattr(child, "forward") and not ignore:
                child = self.make_forward_wrapper(
                    child, current_tag, ignore_layer_list=ignore_layer_list
                )
        return module

    def get_execution_order(self):
        """Return the list of layer tags in the order they were first executed."""
        return self.execution_order
| |
|