import torch
from diffsynth import load_state_dict
from safetensors.torch import save_file


class SingleKVCacheModel(torch.nn.Module):
    """Hold one key tensor and one value tensor as module parameters."""

    def __init__(self, shape):
        """Register zero-initialised ``k`` and ``v`` parameters of ``shape``."""
        super().__init__()
        # ``k`` is registered before ``v`` so the state-dict key order stays
        # ``k`` then ``v``.
        self.k = torch.nn.Parameter(torch.zeros(shape))
        self.v = torch.nn.Parameter(torch.zeros(shape))

    def forward(self):
        """Return the stored parameters as a ``(k, v)`` tuple."""
        return (self.k, self.v)


class StaticKVCacheModel(torch.nn.Module):
    """Container of per-block key/value parameters addressed by block name.

    The block names are ``double_0`` .. ``double_4`` followed by
    ``single_0`` .. ``single_19``; each name owns one ``SingleKVCacheModel``.
    """

    def __init__(self):
        """Create the block-name list and one cache entry per block."""
        super().__init__()
        double_names = [f"double_{i}" for i in range(5)]
        single_names = [f"single_{i}" for i in range(20)]
        self.block_names = double_names + single_names

        # Every block uses the same fixed tensor shape.
        # NOTE(review): the meaning of each axis is not visible in this file.
        entry_shape = (1, 4608, 24, 128)
        entries = [SingleKVCacheModel(entry_shape) for _ in self.block_names]
        self.cache = torch.nn.ModuleList(entries)

    def load_from_kv_cache(self, kv_cache):
        """Copy tensors from ``kv_cache`` into this module's parameters.

        Args:
            kv_cache: Mapping from block name to an indexable object whose
                item ``0`` is loaded as ``k`` and item ``1`` as ``v``.

        The tensors are applied through ``load_state_dict`` with its default
        arguments.
        """
        tensors = {}
        for index, name in enumerate(self.block_names):
            entry = kv_cache[name]
            tensors[f"cache.{index}.k"] = entry[0]
            tensors[f"cache.{index}.v"] = entry[1]
        self.load_state_dict(tensors)

    @torch.no_grad()
    def process_inputs(self, **kwargs):
        """Accept arbitrary keyword arguments and return an empty dict."""
        return {}

    def forward(self, **kwargs):
        """Return ``{"kv_cache": {block_name: (k, v)}}`` for every block.

        The keyword arguments are accepted but not used.
        """
        per_block = {
            name: entry()
            for name, entry in zip(self.block_names, self.cache)
        }
        return {"kv_cache": per_block}


def convert_from_kv_cache(kv_cache, path):
    """Store ``kv_cache`` in a ``StaticKVCacheModel`` and save it to ``path``.

    Args:
        kv_cache: Mapping accepted by ``StaticKVCacheModel.load_from_kv_cache``.
        path: Destination passed to ``safetensors.torch.save_file``.
    """
    model = StaticKVCacheModel()
    # Cast before loading so the incoming tensors are stored as bfloat16.
    model = model.to(torch.bfloat16)
    model.load_from_kv_cache(kv_cache)
    weights = model.state_dict()
    save_file(weights, path)


class DataAnnotator:
    """Callable that hands back its keyword arguments unchanged."""

    def __call__(self, **kwargs):
        """Return the received keyword arguments as a dict."""
        return kwargs


# Module-level entry points.
# NOTE(review): presumably read by an external template loader; confirm against
# the caller.
TEMPLATE_MODEL = StaticKVCacheModel
TEMPLATE_MODEL_PATH = "model.safetensors"
TEMPLATE_DATA_PROCESSOR = DataAnnotator