# Encoder modules: WD-v1.4 SwinV2 tagger, OpenCLIP towers, T5 and CLIP text embedders.
| import os | |
| import math | |
| import numpy as np | |
| from tqdm import tqdm | |
| from einops import rearrange | |
| from refnet.util import exists, append_dims | |
| from refnet.sampling import tps_warp | |
| from refnet.ldm.openaimodel import Timestep, zero_module | |
| import timm | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms | |
| import torch.nn.functional as F | |
| from huggingface_hub import hf_hub_download | |
| from torch.utils.checkpoint import checkpoint | |
| from safetensors.torch import load_file | |
| from transformers import ( | |
| T5EncoderModel, | |
| T5Tokenizer, | |
| CLIPVisionModelWithProjection, | |
| CLIPTextModel, | |
| CLIPTokenizer, | |
| ) | |
# OpenCLIP architecture name -> pretrained-weights tag (open_clip naming).
versions = {
    "ViT-bigG-14": "laion2b_s39b_b160k",
    "ViT-H-14": "laion2b_s32b_b79k",  # resblocks layers: 32
    "ViT-L-14": "laion2b_s32b_b82k",
    "hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None,  # arch name [DFN-ViT-H]
}
# The same architectures mapped to Hugging Face Hub repo ids (transformers checkpoints).
hf_versions = {
    "ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    "ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    "ViT-L-14": "openai/clip-vit-large-patch14",
}
# Local cache directory for downloaded weights; honours HF_HOME when set.
cache_dir = os.environ.get("HF_HOME", "./pretrained_models")
class WDv14SwinTransformerV2(nn.Module):
    """
    Feature extractor / tagger wrapper around the WD-v1.4 SwinV2 tagger.

    Author: Smiling Wolf
    Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2
    """
    # Sentinel logit assigned to channels filtered out by ``logit_threshold``.
    negative_logit = -22

    def __init__(
            self,
            input_size = 448,
            antialias = True,
            layer_idx = 0.,
            load_tag = False,
            logit_threshold = None,
            direct_forward = False,
    ):
        """
        Args:
            input_size: Input image size
            antialias: Antialias during rescaling
            layer_idx: Extracted feature layer, counted back from the last block
            load_tag: Set it to true if use the embedder for image classification
            logit_threshold: Filtering specific channels in logits output
            direct_forward: Route ``forward`` through timm's ``forward_features``
                instead of the layer-sliced ``transformer_forward``
        """
        from refnet.modules import wd_v14_swin2_tagger_config
        super().__init__()
        custom_config = wd_v14_swin2_tagger_config()
        self.model: nn.Module = timm.create_model(
            custom_config.architecture,
            pretrained = False,
            num_classes = custom_config.num_classes,
            global_pool = custom_config.global_pool,
            **custom_config.model_args
        )
        self.image_size = input_size
        self.antialias = antialias
        self.layer_idx = layer_idx
        self.load_tag = load_tag
        self.logit_threshold = logit_threshold
        self.direct_forward = direct_forward
        self.load_from_pretrained_url(load_tag)
        self.get_transformer_length()
        # The tagger is frozen: inference only, no gradients.
        self.model.eval()
        self.model.requires_grad_(False)
        if self.direct_forward:
            # Rebind forward so self.model(x) skips the classification head.
            self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)

    def load_from_pretrained_url(self, load_tag=False):
        """Download tagger weights (and, optionally, the tag metadata CSV) and load them."""
        import pandas as pd
        from torch.hub import download_url_to_file
        from data.tag_utils import load_labels, color_tag_index, geometry_tag_index
        ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
        if not os.path.exists(ckpt_path):
            # Download to a temp name and rename, so a partial download never
            # masquerades as a complete checkpoint.
            cache_path = os.path.join(cache_dir, "weights.tmp")
            download_url_to_file(
                "https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
                dst = cache_path
            )
            os.rename(cache_path, ckpt_path)
        if load_tag:
            csv_path = hf_hub_download(
                "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
                "selected_tags.csv",
                cache_dir = cache_dir
                # use_auth_token=HF_TOKEN,
            )
            tags_df = pd.read_csv(csv_path)
            sep_tags = load_labels(tags_df)
            self.tag_names = sep_tags[0]
            self.rating_indexes = sep_tags[1]
            self.general_indexes = sep_tags[2]
            self.character_indexes = sep_tags[3]
            self.color_tags = color_tag_index
            self.expr_tags = geometry_tag_index
        self.model.load_state_dict(load_file(ckpt_path))

    def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
        """Turn raw per-tag scores into a human-readable tag string.

        Args:
            pred: Array whose ``pred[0]`` holds one score per tag.
            general_thresh: Confidence cutoff for general tags.
            character_thresh: Confidence cutoff for character tags.
        Returns:
            Tuple of (comma-separated tag string, general tags sorted by
            confidence as a list of (name, score) pairs).
        """
        assert self.load_tag
        labels = list(zip(self.tag_names, pred[0].astype(float)))
        # First 4 labels are actually ratings; they are currently unused here.
        # General tags: pick any where prediction confidence > threshold.
        general_names = [labels[i] for i in self.general_indexes]
        general_res = dict(
            (x[0], np.round(x[1], decimals=4)) for x in general_names if x[1] > general_thresh
        )
        # Character tags: pick any where prediction confidence > threshold.
        character_names = [labels[i] for i in self.character_indexes]
        character_res = dict(x for x in character_names if x[1] > character_thresh)
        # Sort once, highest confidence first (the old code sorted the same
        # dict twice into two identical variables).
        sorted_general_res = sorted(
            general_res.items(),
            key=lambda x: x[1],
            reverse=True,
        )
        sorted_general_strings = ", ".join(
            x[0] for x in sorted_general_res
        ).replace("(", "\\(").replace(")", "\\)")
        # Bugfix: join the general and character groups with an explicit ", "
        # separator; the old code concatenated them without one.
        parts = [p for p in (sorted_general_strings, ", ".join(character_res)) if p]
        return ", ".join(parts), sorted_general_res

    def get_transformer_length(self):
        """Cache the total number of SwinV2 blocks for ``layer_idx`` indexing."""
        self.transformer_length = sum(len(stage.blocks) for stage in self.model.layers)

    def transformer_forward(self, x):
        """Run patch embedding + transformer stages, optionally stopping early.

        Returns the hidden state ``layer_idx`` blocks before the end
        (``layer_idx == 0`` runs every block).
        """
        idx = 0
        x = self.model.patch_embed(x)
        for stage in self.model.layers:
            x = stage.downsample(x)
            for blk in stage.blocks:
                if idx == self.transformer_length - self.layer_idx:
                    return x
                if not torch.jit.is_scripting():
                    # Gradient checkpointing to cut activation memory.
                    x = checkpoint(blk, x, use_reentrant=False)
                else:
                    x = blk(x)
                idx += 1
        return x

    def forward(self, x, return_logits=False, pooled=True, **kwargs):
        """Encode a preprocessed image batch.

        Args:
            x: Preprocessed image batch (see :meth:`preprocess`).
            return_logits: Also compute classification logits.
            pooled: Spatially pool features (and logits) to a single token.
        Returns:
            ``[features, logits]``; ``logits`` is None unless requested.
        """
        if self.direct_forward:
            x = self.model(x)
        else:
            x = self.transformer_forward(x)
            x = self.model.norm(x)
        # x: [b, 14, 14, 1024]
        if return_logits:
            if pooled:
                logits = self.model.forward_head(x).unsqueeze(1)
                # logits: [b, 1, num_classes]
            else:
                logits = self.model.head.fc(x)
                # x = F.sigmoid(x)
                logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
                # logits: [b, 196, 9083]
            # Need a threshold to cut off unnecessary classes.
            # NOTE(review): original indentation was lost in extraction; the
            # threshold is applied to both pooled and dense logits here — confirm.
            if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
                logits = torch.where(
                    logits > self.logit_threshold,
                    logits,
                    torch.ones_like(logits) * self.negative_logit
                )
        else:
            logits = None
        if pooled:
            x = x.mean(dim=[1, 2]).unsqueeze(1)
        else:
            x = rearrange(x, "b h w c -> b (h w) c").contiguous()
        return [x, logits]

    def preprocess(self, x: torch.Tensor):
        """Resize to the tagger's input resolution and reorder channels."""
        x = F.interpolate(
            x,
            (self.image_size, self.image_size),
            mode = "bicubic",
            align_corners = True,
            antialias = self.antialias
        )
        # convert RGB to BGR
        x = x[:, [2, 1, 0]]
        return x

    def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
        # Input image must be in RGB format
        return self(self.preprocess(img), return_logits, pooled)

    def predict_labels(self, img: torch.Tensor, *args, **kwargs):
        """Predict tag strings for a single RGB image; see :meth:`convert_labels`."""
        assert len(img.shape) == 4 and img.shape[0] == 1
        logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
        # Bugfix: pooled logits are [1, 1, num_classes]; drop the batch axis so
        # convert_labels sees [1, num_classes] and pred[0] yields per-tag scores.
        # (torch.sigmoid replaces the deprecated F.sigmoid alias.)
        logits = torch.sigmoid(logits).detach().cpu().numpy()[0]
        return self.convert_labels(logits, *args, **kwargs)

    def geometry_update(self, emb, geometry_emb, scale_factor=1):
        """Replace geometry-related channels of ``emb`` with those of a sketch.

        Args:
            emb: WD embedding from reference image
            geometry_emb: WD embedding from sketch image
            scale_factor: Multiplier applied to the injected geometry channels.
        """
        geometry_mask = torch.zeros_like(emb)
        geometry_mask[:, :, self.expr_tags] = 1  # Only geometry channels
        emb = emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor
        return emb

    def dtype(self):
        # Dtype of the classifier head's weights (proxy for the model dtype).
        return self.model.head.fc.weight.dtype
class OpenCLIP(nn.Module):
    """Bundles an OpenCLIP image tower and text tower behind one module.

    ``kwargs`` shared by both towers are merged into each tower's own config.
    """

    def __init__(self, vision_config=None, text_config=None, **kwargs):
        super().__init__()
        # Merge shared kwargs into each per-tower config (or use kwargs alone).
        if not exists(vision_config):
            vision_config = kwargs
        else:
            vision_config.update(kwargs)
        if not exists(text_config):
            text_config = kwargs
        else:
            text_config.update(kwargs)
        self.visual = FrozenOpenCLIPImageEmbedder(**vision_config)
        self.transformer = FrozenOpenCLIPEmbedder(**text_config)

    def preprocess(self, x):
        """Delegate image preprocessing to the visual tower."""
        return self.visual.preprocess(x)

    def scale_factor(self):
        """Current input scale factor of the visual tower."""
        return self.visual.scale_factor

    def update_scale_factor(self, scale_factor):
        """Forward a scale-factor change to the visual tower."""
        self.visual.update_scale_factor(scale_factor)

    def encode(self, *args, **kwargs):
        """Encode images via the visual tower."""
        return self.visual.encode(*args, **kwargs)

    def encode_text(self, text, normalize=True):
        """Encode text via the text tower."""
        return self.transformer(text, normalize)

    def calculate_scale(self, v: torch.Tensor, t: torch.Tensor):
        """
        Calculate the projection of v along the direction of t
        params:
            v: visual tokens from clip image encoder, shape: (b, n, c)
            t: text features from clip text encoder (argmax -1), shape: (b, 1, c)
        """
        return torch.matmul(v, t.mT)
class HFCLIPVisionModel(nn.Module):
    """Frozen CLIP vision tower (Hugging Face ``transformers``) with CLIP preprocessing."""
    # TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future

    def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.):
        super().__init__()
        self.model = CLIPVisionModelWithProjection.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.image_size = image_size
        self.scale_factor = scale_factor
        # CLIP normalization statistics, broadcastable over (b, c, h, w).
        self.register_buffer(
            'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False
        )
        self.antialias = True
        self.requires_grad_(False).eval()

    def preprocess(self, x):
        """Resize to ``image_size * scale_factor`` and map [-1, 1] input to CLIP stats."""
        ns = int(self.image_size * self.scale_factor)
        x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias)
        # normalize to [0,1]
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = (x - self.mean) / self.std
        return x

    def forward(self, x, output_type):
        """Return projected hidden states: "cls" (CLS token), "local" (patches), else all."""
        outputs = self.model(x).last_hidden_state
        if output_type == "cls":
            outputs = outputs[:, :1]
        elif output_type == "local":
            outputs = outputs[:, 1:]
        outputs = self.model.vision_model.post_layernorm(outputs)
        outputs = self.model.visual_projection(outputs)
        return outputs

    def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs):
        """Encode images, optionally applying TPS warping to a random subset.

        Args:
            img: Image batch; preprocessed here unless ``preprocess`` is False.
            output_type: Token selection, see :meth:`forward`.
            preprocess: Run :meth:`preprocess` first.
            warp_p: Per-sample probability of replacing the image with its TPS warp.
        """
        img = self.preprocess(img) if preprocess else img
        if warp_p > 0.:
            rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim)
            # Bugfix: pass the boolean mask directly. The old code wrapped it
            # in torch.Tensor(...), producing a float tensor that torch.where
            # rejects as a condition.
            img = torch.where(rand > warp_p, img, tps_warp(img))
        return self(img, output_type)
class FrozenT5Embedder(nn.Module):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir)
        self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch to eval mode and disable gradients on every parameter."""
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Tokenize *text* and return the encoder's last hidden states."""
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )["input_ids"].to(self.device)
        # Run T5 in full precision even inside a surrounding autocast region.
        with torch.autocast("cuda", enabled=False):
            hidden = self.transformer(input_ids=tokens).last_hidden_state
        return hidden

    def encode(self, text):
        """Alias for calling the module directly."""
        return self(text)
class HFCLIPTextEmbedder(nn.Module):
    """CLIP text encoder (Hugging Face ``transformers``), optionally frozen."""

    def __init__(self, arch, freeze=True, device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.model = CLIPTextModel.from_pretrained(
            hf_versions[arch],
            cache_dir = cache_dir
        )
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch to eval mode and disable gradients on every parameter."""
        self.model = self.model.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Encode *text* — raw strings or pre-tokenized long ids — to hidden states."""
        if isinstance(text, torch.Tensor) and text.dtype == torch.long:
            # Input is already tokenized
            tokens = text
        else:
            # Need to tokenize text input
            encoded = self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )
            tokens = encoded["input_ids"].to(self.device)
        return self.model(input_ids=tokens).last_hidden_state

    def encode(self, text, normalize=False):
        """Encode text; optionally L2-normalize along the feature axis."""
        hidden = self(text)
        if normalize:
            hidden = hidden / hidden.norm(dim=-1, keepdim=True)
        return hidden
class ScalarEmbedder(nn.Module):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, embed_dim, out_dim):
        super().__init__()
        self.timestep = Timestep(embed_dim)
        # The last projection is zero-initialized so the embedding starts as a no-op.
        self.embed_layer = nn.Sequential(
            nn.Linear(embed_dim, out_dim),
            nn.SiLU(),
            zero_module(nn.Linear(out_dim, out_features=out_dim)),
        )

    def forward(self, x, dtype=torch.float32):
        sinusoid = self.timestep(x)
        # Add a singleton token axis: (b, d) -> (b, 1, d).
        sinusoid = rearrange(sinusoid, "b d -> b 1 d")
        return self.embed_layer(sinusoid.to(dtype))
class TimestepEmbedding(nn.Module):
    """Thin module wrapper around the sinusoidal ``Timestep`` embedding."""

    def __init__(self, embed_dim):
        super().__init__()
        self.timestep = Timestep(embed_dim)

    def forward(self, x):
        return self.timestep(x)
if __name__ == '__main__':
    # Smoke test: encode one local image with the DFN CLIP image tower.
    # NOTE(review): FrozenOpenCLIPImageEmbedder is defined elsewhere in this file.
    import PIL.Image as Image
    encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H")
    image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB")
    # ToTensor yields [0, 1]; rescale to the [-1, 1] range the encoder expects.
    image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2
    image = image.unsqueeze(0)
    print(image.shape)
    feat = encoder.encode(image, "local")
    print(feat.shape)