kelseye's picture
Upload folder using huggingface_hub
580dfb0 verified
import torch
from diffsynth import load_state_dict
from safetensors.torch import save_file
class SingleKVCacheModel(torch.nn.Module):
    """Hold one pair of learnable tensors registered as ``k`` and ``v``.

    Both tensors are created zero-filled with the same ``shape`` and are
    registered as parameters, so they appear in the module's ``state_dict``
    under the keys ``k`` and ``v``.
    """

    def __init__(self, shape):
        """Create the zero-initialised ``k`` and ``v`` parameters.

        Args:
            shape: Size handed to ``torch.zeros`` for both tensors.
        """
        super().__init__()
        # Registering k before v keeps the state_dict key order (k, v).
        for slot in ("k", "v"):
            self.register_parameter(slot, torch.nn.Parameter(torch.zeros(shape)))

    def forward(self):
        """Return the stored parameters as a ``(k, v)`` tuple."""
        key, value = self.k, self.v
        return key, value
class StaticKVCacheModel(torch.nn.Module):
    """Input-independent key/value cache with one learnable (k, v) pair per block.

    Blocks are named ``double_0 .. double_{D-1}`` followed by
    ``single_0 .. single_{S-1}``. Entry ``i`` of ``self.cache`` belongs to
    ``self.block_names[i]``, so parameters are stored in the state dict as
    ``cache.{i}.k`` and ``cache.{i}.v``.

    NOTE(review): the "double"/"single" names presumably mirror the
    transformer block layout of the model that consumes this cache, and the
    default shape presumably reads as (batch, tokens, heads, head_dim) --
    confirm against the consumer.
    """

    def __init__(self, num_double_blocks=5, num_single_blocks=20, kv_shape=(1, 4608, 24, 128)):
        """Allocate one zero-initialised (k, v) pair per block.

        The defaults reproduce the previously hard-coded configuration, so
        ``StaticKVCacheModel()`` builds exactly the same module as before.

        Args:
            num_double_blocks: Number of ``double_{i}`` entries.
            num_single_blocks: Number of ``single_{i}`` entries.
            kv_shape: Shape used for every ``k`` and ``v`` parameter.
        """
        super().__init__()
        double_names = [f"double_{i}" for i in range(num_double_blocks)]
        single_names = [f"single_{i}" for i in range(num_single_blocks)]
        # Order matters: a block's position in this list is its index in
        # ``self.cache`` and therefore part of its state_dict key.
        self.block_names = double_names + single_names
        self.cache = torch.nn.ModuleList(
            [SingleKVCacheModel(kv_shape) for _ in self.block_names]
        )

    def load_from_kv_cache(self, kv_cache):
        """Copy an existing KV cache into this module's parameters.

        Args:
            kv_cache: Mapping from block name to an indexable pair whose
                element 0 is the key tensor and element 1 the value tensor.
                It must contain every name in ``self.block_names``.

        Raises:
            KeyError: If a block name is missing from ``kv_cache``.
            RuntimeError: Raised by ``load_state_dict`` (default strict mode)
                when a tensor's shape does not match the parameter.
        """
        state_dict = {}
        for block_id, block_name in enumerate(self.block_names):
            entry = kv_cache[block_name]
            state_dict[f"cache.{block_id}.k"] = entry[0]
            state_dict[f"cache.{block_id}.v"] = entry[1]
        self.load_state_dict(state_dict)

    @torch.no_grad()
    def process_inputs(self, **kwargs):
        """Return an empty dict; the keyword arguments are ignored."""
        return {}

    def forward(self, **kwargs):
        """Return every block's (k, v) pair; the keyword arguments are ignored.

        Returns:
            ``{"kv_cache": {block_name: (k, v), ...}}`` with one entry per
            name in ``self.block_names``.
        """
        kv_cache = {
            block_name: cache()
            for block_name, cache in zip(self.block_names, self.cache)
        }
        return {"kv_cache": kv_cache}
def convert_from_kv_cache(kv_cache, path):
    """Write a KV cache to ``path`` as a ``StaticKVCacheModel`` checkpoint.

    A default ``StaticKVCacheModel`` is built and cast to bfloat16, the
    cache tensors are loaded into it, and the resulting state dict is saved
    with ``safetensors``.

    Args:
        kv_cache: Mapping accepted by ``StaticKVCacheModel.load_from_kv_cache``.
        path: Destination file passed to ``save_file``.
    """
    template = StaticKVCacheModel()
    # Cast before loading so the parameters that receive the cache are bfloat16.
    template = template.to(torch.bfloat16)
    template.load_from_kv_cache(kv_cache)
    weights = template.state_dict()
    save_file(weights, path)
class DataAnnotator:
    """Pass-through data processor that adds no annotations of its own."""

    def __call__(self, **kwargs):
        """Return the received keyword arguments unchanged, as a dict."""
        passthrough = kwargs
        return passthrough
# Module-level configuration names.
# NOTE(review): presumably looked up by name by the DiffSynth template loader --
# confirm against the consumer before renaming or removing any of them.
# Model class (not an instance).
TEMPLATE_MODEL = StaticKVCacheModel
# Weights filename; presumably the file written by convert_from_kv_cache -- TODO confirm.
TEMPLATE_MODEL_PATH = "model.safetensors"
# Data-processor class (not an instance).
TEMPLATE_DATA_PROCESSOR = DataAnnotator