| import torch |
| from einops import rearrange, repeat |
| from .tiler import TileWorker2Dto3D |
|
|
|
|
|
|
| class Downsample3D(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int = 3, |
| stride: int = 2, |
| padding: int = 0, |
| compress_time: bool = False, |
| ): |
| super().__init__() |
|
|
| self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) |
| self.compress_time = compress_time |
|
|
| def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: |
| if self.compress_time: |
| batch_size, channels, frames, height, width = x.shape |
|
|
| |
| x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames) |
|
|
| if x.shape[-1] % 2 == 1: |
| x_first, x_rest = x[..., 0], x[..., 1:] |
| if x_rest.shape[-1] > 0: |
| |
| x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2) |
|
|
| x = torch.cat([x_first[..., None], x_rest], dim=-1) |
| |
| x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) |
| else: |
| |
| x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2) |
| |
| x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2) |
|
|
| |
| pad = (0, 1, 0, 1) |
| x = torch.nn.functional.pad(x, pad, mode="constant", value=0) |
| batch_size, channels, frames, height, width = x.shape |
| |
| x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width) |
| x = self.conv(x) |
| |
| x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) |
| return x |
|
|
|
|
|
|
| class Upsample3D(torch.nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int = 3, |
| stride: int = 1, |
| padding: int = 1, |
| compress_time: bool = False, |
| ) -> None: |
| super().__init__() |
| self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) |
| self.compress_time = compress_time |
|
|
| def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor: |
| if self.compress_time: |
| if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1: |
| |
| x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] |
|
|
| x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0) |
| x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0) |
| x_first = x_first[:, :, None, :, :] |
| inputs = torch.cat([x_first, x_rest], dim=2) |
| elif inputs.shape[2] > 1: |
| inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) |
| else: |
| inputs = inputs.squeeze(2) |
| inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) |
| inputs = inputs[:, :, None, :, :] |
| else: |
| |
| b, c, t, h, w = inputs.shape |
| inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) |
| inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0) |
| inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4) |
|
|
| b, c, t, h, w = inputs.shape |
| inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) |
| inputs = self.conv(inputs) |
| inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4) |
|
|
| return inputs |
|
|
|
|
|
|
| class CogVideoXSpatialNorm3D(torch.nn.Module): |
| def __init__(self, f_channels, zq_channels, groups): |
| super().__init__() |
| self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) |
| self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1) |
| self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1) |
|
|
|
|
| def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: |
| if f.shape[2] > 1 and f.shape[2] % 2 == 1: |
| f_first, f_rest = f[:, :, :1], f[:, :, 1:] |
| f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:] |
| z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] |
| z_first = torch.nn.functional.interpolate(z_first, size=f_first_size) |
| z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size) |
| zq = torch.cat([z_first, z_rest], dim=2) |
| else: |
| zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:]) |
|
|
| norm_f = self.norm_layer(f) |
| new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) |
| return new_f |
|
|
|
|
|
|
| class Resnet3DBlock(torch.nn.Module): |
| def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False): |
| super().__init__() |
| self.nonlinearity = torch.nn.SiLU() |
| if spatial_norm_dim is None: |
| self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) |
| self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) |
| else: |
| self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups) |
| self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups) |
|
|
| self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) |
|
|
| self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) |
|
|
| if in_channels != out_channels: |
| if use_conv_shortcut: |
| self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1)) |
| else: |
| self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1) |
| else: |
| self.conv_shortcut = lambda x: x |
|
|
|
|
| def forward(self, hidden_states, zq): |
| residual = hidden_states |
|
|
| hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states) |
| hidden_states = self.nonlinearity(hidden_states) |
| hidden_states = self.conv1(hidden_states) |
|
|
| hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states) |
| hidden_states = self.nonlinearity(hidden_states) |
| hidden_states = self.conv2(hidden_states) |
|
|
| hidden_states = hidden_states + self.conv_shortcut(residual) |
|
|
| return hidden_states |
| |
|
|
|
|
| class CachedConv3d(torch.nn.Conv3d): |
| def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): |
| super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) |
| self.cached_tensor = None |
|
|
|
|
| def clear_cache(self): |
| self.cached_tensor = None |
| |
|
|
| def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor: |
| if use_cache: |
| if self.cached_tensor is None: |
| self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2) |
| input = torch.concat([self.cached_tensor, input], dim=2) |
| self.cached_tensor = input[:, :, -2:] |
| return super().forward(input) |
|
|
|
|
|
|
| class CogVAEDecoder(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.scaling_factor = 0.7 |
| self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1)) |
|
|
| self.blocks = torch.nn.ModuleList([ |
| Resnet3DBlock(512, 512, 16, 32), |
| Resnet3DBlock(512, 512, 16, 32), |
| Resnet3DBlock(512, 512, 16, 32), |
| Resnet3DBlock(512, 512, 16, 32), |
| Resnet3DBlock(512, 512, 16, 32), |
| Resnet3DBlock(512, 512, 16, 32), |
| Upsample3D(512, 512, compress_time=True), |
| Resnet3DBlock(512, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Upsample3D(256, 256, compress_time=True), |
| Resnet3DBlock(256, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Resnet3DBlock(256, 256, 16, 32), |
| Upsample3D(256, 256, compress_time=False), |
| Resnet3DBlock(256, 128, 16, 32), |
| Resnet3DBlock(128, 128, 16, 32), |
| Resnet3DBlock(128, 128, 16, 32), |
| Resnet3DBlock(128, 128, 16, 32), |
| ]) |
|
|
| self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32) |
| self.conv_act = torch.nn.SiLU() |
| self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1)) |
|
|
|
|
| def forward(self, sample): |
| sample = sample / self.scaling_factor |
| hidden_states = self.conv_in(sample) |
|
|
| for block in self.blocks: |
| hidden_states = block(hidden_states, sample) |
| |
| hidden_states = self.norm_out(hidden_states, sample) |
| hidden_states = self.conv_act(hidden_states) |
| hidden_states = self.conv_out(hidden_states) |
|
|
| return hidden_states |
| |
|
|
| def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x): |
| if tiled: |
| B, C, T, H, W = sample.shape |
| return TileWorker2Dto3D().tiled_forward( |
| forward_fn=lambda x: self.decode_small_video(x), |
| model_input=sample, |
| tile_size=tile_size, tile_stride=tile_stride, |
| tile_device=sample.device, tile_dtype=sample.dtype, |
| computation_device=sample.device, computation_dtype=sample.dtype, |
| scales=(3/16, (T//2*8+T%2)/T, 8, 8), |
| progress_bar=progress_bar |
| ) |
| else: |
| return self.decode_small_video(sample) |
| |
|
|
| def decode_small_video(self, sample): |
| B, C, T, H, W = sample.shape |
| computation_device = self.conv_in.weight.device |
| computation_dtype = self.conv_in.weight.dtype |
| value = [] |
| for i in range(T//2): |
| tl = i*2 + T%2 - (T%2 and i==0) |
| tr = i*2 + 2 + T%2 |
| model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device) |
| model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device) |
| value.append(model_output) |
| value = torch.concat(value, dim=2) |
| for name, module in self.named_modules(): |
| if isinstance(module, CachedConv3d): |
| module.clear_cache() |
| return value |
| |
|
|
| @staticmethod |
| def state_dict_converter(): |
| return CogVAEDecoderStateDictConverter() |
| |
|
|
|
|
| class CogVAEEncoder(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.scaling_factor = 0.7 |
| self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1)) |
|
|
| self.blocks = torch.nn.ModuleList([ |
| Resnet3DBlock(128, 128, None, 32), |
| Resnet3DBlock(128, 128, None, 32), |
| Resnet3DBlock(128, 128, None, 32), |
| Downsample3D(128, 128, compress_time=True), |
| Resnet3DBlock(128, 256, None, 32), |
| Resnet3DBlock(256, 256, None, 32), |
| Resnet3DBlock(256, 256, None, 32), |
| Downsample3D(256, 256, compress_time=True), |
| Resnet3DBlock(256, 256, None, 32), |
| Resnet3DBlock(256, 256, None, 32), |
| Resnet3DBlock(256, 256, None, 32), |
| Downsample3D(256, 256, compress_time=False), |
| Resnet3DBlock(256, 512, None, 32), |
| Resnet3DBlock(512, 512, None, 32), |
| Resnet3DBlock(512, 512, None, 32), |
| Resnet3DBlock(512, 512, None, 32), |
| Resnet3DBlock(512, 512, None, 32), |
| ]) |
|
|
| self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True) |
| self.conv_act = torch.nn.SiLU() |
| self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1)) |
|
|
|
|
| def forward(self, sample): |
| hidden_states = self.conv_in(sample) |
|
|
| for block in self.blocks: |
| hidden_states = block(hidden_states, sample) |
| |
| hidden_states = self.norm_out(hidden_states) |
| hidden_states = self.conv_act(hidden_states) |
| hidden_states = self.conv_out(hidden_states)[:, :16] |
| hidden_states = hidden_states * self.scaling_factor |
|
|
| return hidden_states |
| |
|
|
| def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x): |
| if tiled: |
| B, C, T, H, W = sample.shape |
| return TileWorker2Dto3D().tiled_forward( |
| forward_fn=lambda x: self.encode_small_video(x), |
| model_input=sample, |
| tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride), |
| tile_device=sample.device, tile_dtype=sample.dtype, |
| computation_device=sample.device, computation_dtype=sample.dtype, |
| scales=(16/3, (T//4+T%2)/T, 1/8, 1/8), |
| progress_bar=progress_bar |
| ) |
| else: |
| return self.encode_small_video(sample) |
| |
|
|
| def encode_small_video(self, sample): |
| B, C, T, H, W = sample.shape |
| computation_device = self.conv_in.weight.device |
| computation_dtype = self.conv_in.weight.dtype |
| value = [] |
| for i in range(T//8): |
| t = i*8 + T%2 - (T%2 and i==0) |
| t_ = i*8 + 8 + T%2 |
| model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device) |
| model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device) |
| value.append(model_output) |
| value = torch.concat(value, dim=2) |
| for name, module in self.named_modules(): |
| if isinstance(module, CachedConv3d): |
| module.clear_cache() |
| return value |
| |
|
|
| @staticmethod |
| def state_dict_converter(): |
| return CogVAEEncoderStateDictConverter() |
|
|
|
|
|
|
| class CogVAEEncoderStateDictConverter: |
| def __init__(self): |
| pass |
|
|
|
|
| def from_diffusers(self, state_dict): |
| rename_dict = { |
| "encoder.conv_in.conv.weight": "conv_in.weight", |
| "encoder.conv_in.conv.bias": "conv_in.bias", |
| "encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight", |
| "encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias", |
| "encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight", |
| "encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias", |
| "encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight", |
| "encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias", |
| "encoder.norm_out.weight": "norm_out.weight", |
| "encoder.norm_out.bias": "norm_out.bias", |
| "encoder.conv_out.conv.weight": "conv_out.weight", |
| "encoder.conv_out.conv.bias": "conv_out.bias", |
| } |
| prefix_dict = { |
| "encoder.down_blocks.0.resnets.0.": "blocks.0.", |
| "encoder.down_blocks.0.resnets.1.": "blocks.1.", |
| "encoder.down_blocks.0.resnets.2.": "blocks.2.", |
| "encoder.down_blocks.1.resnets.0.": "blocks.4.", |
| "encoder.down_blocks.1.resnets.1.": "blocks.5.", |
| "encoder.down_blocks.1.resnets.2.": "blocks.6.", |
| "encoder.down_blocks.2.resnets.0.": "blocks.8.", |
| "encoder.down_blocks.2.resnets.1.": "blocks.9.", |
| "encoder.down_blocks.2.resnets.2.": "blocks.10.", |
| "encoder.down_blocks.3.resnets.0.": "blocks.12.", |
| "encoder.down_blocks.3.resnets.1.": "blocks.13.", |
| "encoder.down_blocks.3.resnets.2.": "blocks.14.", |
| "encoder.mid_block.resnets.0.": "blocks.15.", |
| "encoder.mid_block.resnets.1.": "blocks.16.", |
| } |
| suffix_dict = { |
| "norm1.norm_layer.weight": "norm1.norm_layer.weight", |
| "norm1.norm_layer.bias": "norm1.norm_layer.bias", |
| "norm1.conv_y.conv.weight": "norm1.conv_y.weight", |
| "norm1.conv_y.conv.bias": "norm1.conv_y.bias", |
| "norm1.conv_b.conv.weight": "norm1.conv_b.weight", |
| "norm1.conv_b.conv.bias": "norm1.conv_b.bias", |
| "norm2.norm_layer.weight": "norm2.norm_layer.weight", |
| "norm2.norm_layer.bias": "norm2.norm_layer.bias", |
| "norm2.conv_y.conv.weight": "norm2.conv_y.weight", |
| "norm2.conv_y.conv.bias": "norm2.conv_y.bias", |
| "norm2.conv_b.conv.weight": "norm2.conv_b.weight", |
| "norm2.conv_b.conv.bias": "norm2.conv_b.bias", |
| "conv1.conv.weight": "conv1.weight", |
| "conv1.conv.bias": "conv1.bias", |
| "conv2.conv.weight": "conv2.weight", |
| "conv2.conv.bias": "conv2.bias", |
| "conv_shortcut.weight": "conv_shortcut.weight", |
| "conv_shortcut.bias": "conv_shortcut.bias", |
| "norm1.weight": "norm1.weight", |
| "norm1.bias": "norm1.bias", |
| "norm2.weight": "norm2.weight", |
| "norm2.bias": "norm2.bias", |
| } |
| state_dict_ = {} |
| for name, param in state_dict.items(): |
| if name in rename_dict: |
| state_dict_[rename_dict[name]] = param |
| else: |
| for prefix in prefix_dict: |
| if name.startswith(prefix): |
| suffix = name[len(prefix):] |
| state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param |
| return state_dict_ |
| |
|
|
| def from_civitai(self, state_dict): |
| return self.from_diffusers(state_dict) |
|
|
|
|
|
|
| class CogVAEDecoderStateDictConverter: |
| def __init__(self): |
| pass |
|
|
|
|
| def from_diffusers(self, state_dict): |
| rename_dict = { |
| "decoder.conv_in.conv.weight": "conv_in.weight", |
| "decoder.conv_in.conv.bias": "conv_in.bias", |
| "decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight", |
| "decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias", |
| "decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight", |
| "decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias", |
| "decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight", |
| "decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias", |
| "decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight", |
| "decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias", |
| "decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight", |
| "decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias", |
| "decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight", |
| "decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias", |
| "decoder.conv_out.conv.weight": "conv_out.weight", |
| "decoder.conv_out.conv.bias": "conv_out.bias" |
| } |
| prefix_dict = { |
| "decoder.mid_block.resnets.0.": "blocks.0.", |
| "decoder.mid_block.resnets.1.": "blocks.1.", |
| "decoder.up_blocks.0.resnets.0.": "blocks.2.", |
| "decoder.up_blocks.0.resnets.1.": "blocks.3.", |
| "decoder.up_blocks.0.resnets.2.": "blocks.4.", |
| "decoder.up_blocks.0.resnets.3.": "blocks.5.", |
| "decoder.up_blocks.1.resnets.0.": "blocks.7.", |
| "decoder.up_blocks.1.resnets.1.": "blocks.8.", |
| "decoder.up_blocks.1.resnets.2.": "blocks.9.", |
| "decoder.up_blocks.1.resnets.3.": "blocks.10.", |
| "decoder.up_blocks.2.resnets.0.": "blocks.12.", |
| "decoder.up_blocks.2.resnets.1.": "blocks.13.", |
| "decoder.up_blocks.2.resnets.2.": "blocks.14.", |
| "decoder.up_blocks.2.resnets.3.": "blocks.15.", |
| "decoder.up_blocks.3.resnets.0.": "blocks.17.", |
| "decoder.up_blocks.3.resnets.1.": "blocks.18.", |
| "decoder.up_blocks.3.resnets.2.": "blocks.19.", |
| "decoder.up_blocks.3.resnets.3.": "blocks.20.", |
| } |
| suffix_dict = { |
| "norm1.norm_layer.weight": "norm1.norm_layer.weight", |
| "norm1.norm_layer.bias": "norm1.norm_layer.bias", |
| "norm1.conv_y.conv.weight": "norm1.conv_y.weight", |
| "norm1.conv_y.conv.bias": "norm1.conv_y.bias", |
| "norm1.conv_b.conv.weight": "norm1.conv_b.weight", |
| "norm1.conv_b.conv.bias": "norm1.conv_b.bias", |
| "norm2.norm_layer.weight": "norm2.norm_layer.weight", |
| "norm2.norm_layer.bias": "norm2.norm_layer.bias", |
| "norm2.conv_y.conv.weight": "norm2.conv_y.weight", |
| "norm2.conv_y.conv.bias": "norm2.conv_y.bias", |
| "norm2.conv_b.conv.weight": "norm2.conv_b.weight", |
| "norm2.conv_b.conv.bias": "norm2.conv_b.bias", |
| "conv1.conv.weight": "conv1.weight", |
| "conv1.conv.bias": "conv1.bias", |
| "conv2.conv.weight": "conv2.weight", |
| "conv2.conv.bias": "conv2.bias", |
| "conv_shortcut.weight": "conv_shortcut.weight", |
| "conv_shortcut.bias": "conv_shortcut.bias", |
| } |
| state_dict_ = {} |
| for name, param in state_dict.items(): |
| if name in rename_dict: |
| state_dict_[rename_dict[name]] = param |
| else: |
| for prefix in prefix_dict: |
| if name.startswith(prefix): |
| suffix = name[len(prefix):] |
| state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param |
| return state_dict_ |
| |
|
|
| def from_civitai(self, state_dict): |
| return self.from_diffusers(state_dict) |
|
|
|
|