import torch
from diffsynth import load_state_dict
from safetensors.torch import save_file
class SingleKVCacheModel(torch.nn.Module):
    """Key/value cache storage for a single transformer block.

    The two tensors are registered as parameters so they participate in
    ``state_dict`` save/load.
    """

    def __init__(self, shape):
        super().__init__()
        # Zero-initialized placeholders; real values arrive via load_state_dict.
        self.k = torch.nn.Parameter(torch.zeros(shape))
        self.v = torch.nn.Parameter(torch.zeros(shape))

    def forward(self):
        """Return the stored (key, value) pair."""
        return self.k, self.v
class StaticKVCacheModel(torch.nn.Module):
    """Full transformer KV cache (5 "double" + 20 "single" blocks) wrapped as
    an ``nn.Module`` so it can be saved and restored with state-dict tooling.
    """

    def __init__(self):
        super().__init__()
        double_names = [f"double_{i}" for i in range(5)]
        single_names = [f"single_{i}" for i in range(20)]
        self.block_names = double_names + single_names
        # One (k, v) parameter pair per block; the fixed shape presumably is
        # (batch, sequence, heads, head_dim) — TODO confirm against the producer.
        self.cache = torch.nn.ModuleList(
            SingleKVCacheModel((1, 4608, 24, 128)) for _ in self.block_names
        )

    def load_from_kv_cache(self, kv_cache):
        """Copy tensors from a ``{block_name: (k, v)}`` mapping into this module.

        Raises if ``kv_cache`` is missing a block name or a tensor shape
        does not match (standard ``load_state_dict`` behavior).
        """
        flat = {}
        for idx, name in enumerate(self.block_names):
            key_tensor, value_tensor = kv_cache[name]
            flat[f"cache.{idx}.k"] = key_tensor
            flat[f"cache.{idx}.v"] = value_tensor
        self.load_state_dict(flat)

    @torch.no_grad()
    def process_inputs(self, **kwargs):
        # No preprocessing needed; present for pipeline-interface compatibility.
        return {}

    def forward(self, **kwargs):
        """Rebuild the ``{block_name: (k, v)}`` mapping from stored parameters."""
        rebuilt = {
            name: block()
            for name, block in zip(self.block_names, self.cache)
        }
        return {"kv_cache": rebuilt}
def convert_from_kv_cache(kv_cache, path):
    """Serialize a ``{block_name: (k, v)}`` cache to a safetensors file.

    The tensors are cast to bfloat16 (via the container module's dtype)
    before being written to ``path``.
    """
    container = StaticKVCacheModel().to(torch.bfloat16)
    container.load_from_kv_cache(kv_cache)
    save_file(container.state_dict(), path)
class DataAnnotator:
    """Identity data processor: returns its keyword arguments unchanged.

    Present so the pipeline's data-processor slot has a concrete, no-op
    implementation.
    """

    def __call__(self, **kwargs):
        return kwargs
# Module-level entry points — presumably consumed by a surrounding
# template/pipeline framework that imports them by name; verify against caller.
TEMPLATE_MODEL = StaticKVCacheModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator