---
license: apache-2.0
---
```python
# You can use the following code to call our trained style encoder. Hope it helps.
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection


class SEStyleEmbedding:
    def __init__(self, pretrained_path: str = "xingpng/OneIG-StyleEncoder", device: str = "cuda", dtype=torch.bfloat16):
        self.device = torch.device(device)
        self.dtype = dtype
        # Load the trained style encoder (a CLIP vision tower with a projection head).
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
        self.image_encoder.to(self.device, dtype=self.dtype)
        self.image_encoder.eval()
        # Default CLIP preprocessing (resize/center-crop to 224, CLIP mean/std normalization).
        self.processor = CLIPImageProcessor()

    def _l2_normalize(self, x):
        return torch.nn.functional.normalize(x, p=2, dim=-1)

    def get_style_embedding(self, image_path: str):
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
        with torch.no_grad():
            outputs = self.image_encoder(inputs)
        image_embeds = outputs.image_embeds
        # L2-normalize so embeddings can be compared with a dot product (cosine similarity).
        image_embeds_norm = self._l2_normalize(image_embeds)
        return image_embeds_norm
```
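
Since the returned embeddings are L2-normalized, the dot product of two of them is their cosine similarity. Below is a minimal usage sketch along those lines; the image paths and the `similarity` variable are placeholders for illustration, not files or names shipped with this repo.

```python
# Minimal sketch: compare the style of two images with the encoder above.
embedder = SEStyleEmbedding(device="cuda")

emb_a = embedder.get_style_embedding("style_a.png")  # shape: (1, projection_dim)
emb_b = embedder.get_style_embedding("style_b.png")

# Embeddings are already unit-norm, so the dot product equals cosine similarity.
similarity = (emb_a * emb_b).sum(dim=-1).item()
print(f"style similarity: {similarity:.4f}")
```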