tellurion's picture
initialize huggingface space demo
d066167
import os
import math
import numpy as np
from tqdm import tqdm
from einops import rearrange
from refnet.util import exists, append_dims
from refnet.sampling import tps_warp
from refnet.ldm.openaimodel import Timestep, zero_module
import timm
import torch
import torch.nn as nn
import torchvision.transforms
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch.utils.checkpoint import checkpoint
from safetensors.torch import load_file
from transformers import (
T5EncoderModel,
T5Tokenizer,
CLIPVisionModelWithProjection,
CLIPTextModel,
CLIPTokenizer,
)
versions = {
"ViT-bigG-14": "laion2b_s39b_b160k",
"ViT-H-14": "laion2b_s32b_b79k", # resblocks layers: 32
"ViT-L-14": "laion2b_s32b_b82k",
"hf-hub:apple/DFN5B-CLIP-ViT-H-14-384": None, # arch name [DFN-ViT-H]
}
hf_versions = {
"ViT-bigG-14": "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
"ViT-H-14": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
"ViT-L-14": "openai/clip-vit-large-patch14",
}
cache_dir = os.environ.get("HF_HOME", "./pretrained_models")
class WDv14SwinTransformerV2(nn.Module):
"""
WD-v14-tagger
Author: Smiling Wolf
Link: https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2
"""
negative_logit = -22
def __init__(
self,
input_size = 448,
antialias = True,
layer_idx = 0.,
load_tag = False,
logit_threshold = None,
direct_forward = False,
):
"""
Args:
input_size: Input image size
antialias: Antialias during rescaling
layer_idx: Extracted feature layer
load_tag: Set it to true if use the embedder for image classification
logit_threshold: Filtering specific channels in logits output
"""
from refnet.modules import wd_v14_swin2_tagger_config
super().__init__()
custom_config = wd_v14_swin2_tagger_config()
self.model: nn.Module = timm.create_model(
custom_config.architecture,
pretrained = False,
num_classes = custom_config.num_classes,
global_pool = custom_config.global_pool,
**custom_config.model_args
)
self.image_size = input_size
self.antialias = antialias
self.layer_idx = layer_idx
self.load_tag = load_tag
self.logit_threshold = logit_threshold
self.direct_forward = direct_forward
self.load_from_pretrained_url(load_tag)
self.get_transformer_length()
self.model.eval()
self.model.requires_grad_(False)
if self.direct_forward:
self.model.forward = self.model.forward_features.__get__(self.model, self.model.__class__)
def load_from_pretrained_url(self, load_tag=False):
import pandas as pd
from torch.hub import download_url_to_file
from data.tag_utils import load_labels, color_tag_index, geometry_tag_index
ckpt_path = os.path.join(cache_dir, "wd-v14-swin2-tagger.safetensors")
if not os.path.exists(ckpt_path):
cache_path = os.path.join(cache_dir, "weights.tmp")
download_url_to_file(
"https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.safetensors",
dst = cache_path
)
os.rename(cache_path, ckpt_path)
if load_tag:
csv_path = hf_hub_download(
"SmilingWolf/wd-v1-4-swinv2-tagger-v2",
"selected_tags.csv",
cache_dir = cache_dir
# use_auth_token=HF_TOKEN,
)
tags_df = pd.read_csv(csv_path)
sep_tags = load_labels(tags_df)
self.tag_names = sep_tags[0]
self.rating_indexes = sep_tags[1]
self.general_indexes = sep_tags[2]
self.character_indexes = sep_tags[3]
self.color_tags = color_tag_index
self.expr_tags = geometry_tag_index
self.model.load_state_dict(load_file(ckpt_path))
def convert_labels(self, pred, general_thresh=0.25, character_thresh=0.85):
assert self.load_tag
labels = list(zip(self.tag_names, pred[0].astype(float)))
# First 4 labels are actually ratings: pick one with argmax
# ratings_names = [labels[i] for i in self.rating_indexes]
# rating = dict(ratings_names)
# Then we have general tags: pick any where prediction confidence > threshold
general_names = [labels[i] for i in self.general_indexes]
general_res = [(x[0], np.round(x[1], decimals=4)) for x in general_names if x[1] > general_thresh]
general_res = dict(general_res)
# Everything else is characters: pick any where prediction confidence > threshold
character_names = [labels[i] for i in self.character_indexes]
character_res = [x for x in character_names if x[1] > character_thresh]
character_res = dict(character_res)
sorted_general_strings = sorted(
general_res.items(),
key=lambda x: x[1],
reverse=True,
)
sorted_general_res = sorted(
general_res.items(),
key=lambda x: x[1],
reverse=True,
)
sorted_general_strings = [x[0] for x in sorted_general_strings]
sorted_general_strings = ", ".join(sorted_general_strings).replace("(", "\\(").replace(")", "\\)")
# return sorted_general_strings, rating, character_res, general_res
return sorted_general_strings + ", ".join([x[0] for x in character_res.items()]), sorted_general_res
def get_transformer_length(self):
length = 0
for stage in self.model.layers:
length += len(stage.blocks)
self.transformer_length = length
def transformer_forward(self, x):
idx = 0
x = self.model.patch_embed(x)
for stage in self.model.layers:
x = stage.downsample(x)
for blk in stage.blocks:
if idx == self.transformer_length - self.layer_idx:
return x
if not torch.jit.is_scripting():
x = checkpoint(blk, x, use_reentrant=False)
else:
x = blk(x)
idx += 1
return x
def forward(self, x, return_logits=False, pooled=True, **kwargs):
# x: [b, h, w, 3]
if self.direct_forward:
x = self.model(x)
else:
x = self.transformer_forward(x)
x = self.model.norm(x)
# x: [b, 14, 14, 1024]
if return_logits:
if pooled:
logits = self.model.forward_head(x).unsqueeze(1)
# x: [b, 1, 1024]
else:
logits = self.model.head.fc(x)
# x = F.sigmoid(x)
logits = rearrange(logits, "b h w c -> b (h w) c").contiguous()
# x: [b, 196, 9083]
# Need a threshold to cut off unnecessary classes.
if exists(self.logit_threshold) and isinstance(self.logit_threshold, float):
logits = torch.where(
logits > self.logit_threshold,
logits,
torch.ones_like(logits) * self.negative_logit
)
else:
logits = None
if pooled:
x = x.mean(dim=[1, 2]).unsqueeze(1)
else:
x = rearrange(x, "b h w c -> b (h w) c").contiguous()
return [x, logits]
def preprocess(self, x: torch.Tensor):
x = F.interpolate(
x,
(self.image_size, self.image_size),
mode = "bicubic",
align_corners = True,
antialias = self.antialias
)
# convert RGB to BGR
x = x[:, [2, 1, 0]]
return x
@torch.no_grad()
def encode(self, img: torch.Tensor, return_logits=False, pooled=True, **kwargs):
# Input image must be in RGB format
return self(self.preprocess(img), return_logits, pooled)
@torch.no_grad()
def predict_labels(self, img: torch.Tensor, *args, **kwargs):
assert len(img.shape) == 4 and img.shape[0] == 1
logits = self(self.preprocess(img), return_logits=True, pooled=True)[1]
logits = F.sigmoid(logits).detach().cpu().numpy()
return self.convert_labels(logits, *args, **kwargs)
def geometry_update(self, emb, geometry_emb, scale_factor=1):
"""
Args:
emb: WD embedding from reference image
geometry_emb: WD embedding from sketch image
"""
geometry_mask = torch.zeros_like(emb)
geometry_mask[:, :, self.expr_tags] = 1 # Only geometry channels
emb = emb * (1 - geometry_mask) + geometry_emb * geometry_mask * scale_factor
return emb
@property
def dtype(self):
return self.model.head.fc.weight.dtype
class OpenCLIP(nn.Module):
def __init__(self, vision_config=None, text_config=None, **kwargs):
super().__init__()
if exists(vision_config):
vision_config.update(kwargs)
else:
vision_config = kwargs
if exists(text_config):
text_config.update(kwargs)
else:
text_config = kwargs
self.visual = FrozenOpenCLIPImageEmbedder(**vision_config)
self.transformer = FrozenOpenCLIPEmbedder(**text_config)
def preprocess(self, x):
return self.visual.preprocess(x)
@property
def scale_factor(self):
return self.visual.scale_factor
def update_scale_factor(self, scale_factor):
self.visual.update_scale_factor(scale_factor)
def encode(self, *args, **kwargs):
return self.visual.encode(*args, **kwargs)
@torch.no_grad()
def encode_text(self, text, normalize=True):
return self.transformer(text, normalize)
def calculate_scale(self, v: torch.Tensor, t: torch.Tensor):
"""
Calculate the projection of v along the direction of t
params:
v: visual tokens from clip image encoder, shape: (b, n, c)
t: text features from clip text encoder (argmax -1), shape: (b, 1, c)
"""
return v @ t.mT
class HFCLIPVisionModel(nn.Module):
# TODO: open_clip_torch is incompatible with deepspeed ZeRO3, change to huggingface implementation in the future
def __init__(self, arch="ViT-bigG-14", image_size=224, scale_factor=1.):
super().__init__()
self.model = CLIPVisionModelWithProjection.from_pretrained(
hf_versions[arch],
cache_dir = cache_dir
)
self.image_size = image_size
self.scale_factor = scale_factor
self.register_buffer(
'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1), persistent=False
)
self.register_buffer(
'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1), persistent=False
)
self.antialias = True
self.requires_grad_(False).eval()
def preprocess(self, x):
# normalize to [0,1]
ns = int(self.image_size * self.scale_factor)
x = F.interpolate(x, (ns, ns), mode="bicubic", align_corners=True, antialias=self.antialias)
x = (x + 1.0) / 2.0
# renormalize according to clip
x = (x - self.mean) / self.std
return x
def forward(self, x, output_type):
outputs = self.model(x).last_hidden_state
if output_type == "cls":
outputs = outputs[:, :1]
elif output_type == "local":
outputs = outputs[:, 1:]
outputs = self.model.vision_model.post_layernorm(outputs)
outputs = self.model.visual_projection(outputs)
return outputs
@torch.no_grad()
def encode(self, img, output_type="full", preprocess=True, warp_p=0., **kwargs):
img = self.preprocess(img) if preprocess else img
if warp_p > 0.:
rand = append_dims(torch.rand(img.shape[0], device=img.device, dtype=img.dtype), img.ndim)
img = torch.where(torch.Tensor(rand > warp_p), img, tps_warp(img))
return self(img, output_type)
class FrozenT5Embedder(nn.Module):
"""Uses the T5 transformer encoder for text"""
def __init__(
self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir=cache_dir)
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir=cache_dir)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
with torch.autocast("cuda", enabled=False):
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
@torch.no_grad()
def encode(self, text):
return self(text)
class HFCLIPTextEmbedder(nn.Module):
def __init__(self, arch, freeze=True, device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(
hf_versions[arch],
cache_dir = cache_dir
)
self.model = CLIPTextModel.from_pretrained(
hf_versions[arch],
cache_dir = cache_dir
)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
if isinstance(text, torch.Tensor) and text.dtype == torch.long:
# Input is already tokenized
tokens = text
else:
# Need to tokenize text input
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.model(input_ids=tokens)
z = outputs.last_hidden_state
return z
@torch.no_grad()
def encode(self, text, normalize=False):
outputs = self(text)
if normalize:
outputs = outputs / outputs.norm(dim=-1, keepdim=True)
return outputs
class ScalarEmbedder(nn.Module):
"""embeds each dimension independently and concatenates them"""
def __init__(self, embed_dim, out_dim):
super().__init__()
self.timestep = Timestep(embed_dim)
self.embed_layer = nn.Sequential(
nn.Linear(embed_dim, out_dim),
nn.SiLU(),
zero_module(nn.Linear(out_dim, out_features=out_dim))
)
def forward(self, x, dtype=torch.float32):
emb = self.timestep(x)
emb = rearrange(emb, "b d -> b 1 d")
emb = self.embed_layer(emb.to(dtype))
return emb
class TimestepEmbedding(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.timestep = Timestep(embed_dim)
def forward(self, x):
x = self.timestep(x)
return x
if __name__ == '__main__':
import PIL.Image as Image
encoder = FrozenOpenCLIPImageEmbedder(arch="DFN-ViT-H")
image = Image.open("../../miniset/origin/70717450.jpg").convert("RGB")
image = (torchvision.transforms.ToTensor()(image) - 0.5) * 2
image = image.unsqueeze(0)
print(image.shape)
feat = encoder.encode(image, "local")
print(feat.shape)