| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict |
|
|
| import torch |
|
|
|
|
class AttnProcsLayers(torch.nn.Module):
    """Wrap a name -> module mapping so checkpoints keep the original names.

    The modules are stored positionally in ``self.layers`` (a ``ModuleList``),
    so their raw parameter keys look like ``layers.<idx>.<rest>``. Two hooks
    translate between that positional form and the names given at construction:

    * ``state_dict()`` returns keys of the form ``<name>.<rest>``.
    * ``load_state_dict()`` accepts keys of the form ``<name>.<rest>`` and
      rewrites them (in place, on the dict torch hands to the hook) to
      ``layers.<idx>.<rest>`` before loading.

    Args:
        state_dict: Mapping from external name to module. Insertion order
            defines the positional index of each module. Every name is
            expected to end with one of ``split_keys`` so that it can be
            recovered from a full parameter key when loading.

    NOTE(review): the hooks ignore the ``prefix`` argument torch passes them
    and assume raw keys start with ``layers.<idx>``, i.e. that this module is
    the root on which ``state_dict()`` / ``load_state_dict()`` is called.
    """

    def __init__(self, state_dict: Dict[str, torch.nn.Module]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        # index -> external name, and the inverse lookup.
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # Markers that terminate the external name inside a parameter key.
        # Presumably ".processor" is used for UNet attention processors and
        # ".self_attn" for text-encoder attention -- confirm against callers.
        self.split_keys = [".processor", ".self_attn"]

        def map_to(module, state_dict, *args, **kwargs):
            """state_dict hook: rename ``layers.<idx>.<rest>`` to ``<name>.<rest>``."""
            new_state_dict = {}
            for key, value in state_dict.items():
                # Split off only the leading "layers" and "<idx>" segments so
                # that a later "layers.<n>" inside <rest> is left untouched
                # (str.replace would rewrite every occurrence).
                _, num, *rest = key.split(".", 2)
                new_key = ".".join([module.mapping[int(num)], *rest])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            """Return the external-name prefix of ``key`` (up to and including the first split key found).

            Raises:
                ValueError: if ``key`` contains none of ``self.split_keys``.
            """
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k

            raise ValueError(
                f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
            )

        def map_from(module, state_dict, *args, **kwargs):
            """load_state_dict pre-hook: rename ``<name>.<rest>`` to ``layers.<idx>.<rest>`` in place."""
            # Snapshot the keys: the dict is mutated while we walk it.
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                # replace_key is by construction a prefix of key; swap only
                # that prefix instead of every occurrence of the substring.
                new_key = f"layers.{module.rev_mapping[replace_key]}" + key[len(replace_key):]
                state_dict[new_key] = state_dict.pop(key)

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
|