| import hashlib |
| import os |
| import urllib |
| import warnings |
| from typing import Any, Union, List |
| from pkg_resources import packaging |
| from torch import nn |
| import torch |
| from PIL import Image |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
|
|
| from .model_text_encoder import build_model |
| from .simple_tokenizer import SimpleTokenizer as _Tokenizer |
|
|
# Pick the bicubic interpolation constant: use torchvision's InterpolationMode
# enum when it can be imported, otherwise fall back to the PIL constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC


# Module-level tokenizer instance shared by every tokenize() call.
_tokenizer = _Tokenizer()
|
|
|
|
def _convert_image_to_rgb(image):
    """Return the result of calling ``image.convert("RGB")``.

    Parameters
    ----------
    image
        Any object exposing a ``convert(mode)`` method
        (presumably a ``PIL.Image.Image`` -- confirm against callers).

    Returns
    -------
    Whatever ``image.convert("RGB")`` returns.
    """
    rgb_image = image.convert("RGB")
    return rgb_image
|
|
|
|
def load():
    """Build and return the model.

    Delegates to ``build_model`` with ``load_from_clip=False``.
    NOTE(review): the exact effect of ``load_from_clip`` is defined in
    ``model_text_encoder.build_model`` and is not visible here.

    Returns
    -------
    The object returned by ``build_model``.
    """
    return build_model(load_from_clip=False)
| |
|
|
def tokenize(texts: Union[str, List[str]], context_length: int = 77*4-60, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Tokenize one or more strings into a zero-padded tensor of token ids.

    Each input is encoded with the module-level tokenizer and wrapped in the
    start-of-text and end-of-text tokens.

    Parameters
    ----------
    texts : Union[str, List[str]]
        A single string or a list of strings to tokenize; a single string is
        treated as a one-element list

    context_length : int
        Number of token slots per row of the output; defaults to 77*4-60 (248)

    truncate: bool
        If True, an encoding longer than ``context_length`` is cut down to
        ``context_length`` tokens and its last token is replaced with the
        end-of-text token; if False, such an input raises an error

    Returns
    -------
    A two-dimensional tensor of shape [number of input strings, context_length],
    with unused positions left as zero. The dtype is torch.long when the torch
    version is below 1.8.0 and torch.int otherwise.

    Raises
    ------
    RuntimeError
        If an encoded input exceeds ``context_length`` and ``truncate`` is False.
    """
    texts = [texts] if isinstance(texts, str) else texts

    start_id = _tokenizer.encoder["<|startoftext|>"]
    end_id = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[start_id] + _tokenizer.encode(text) + [end_id] for text in texts]

    # Torch releases older than 1.8.0 get a long tensor, newer ones an int tensor.
    is_legacy_torch = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    padded = torch.zeros(
        len(encoded),
        context_length,
        dtype=torch.long if is_legacy_torch else torch.int,
    )

    for i, ids in enumerate(encoded):
        if len(ids) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            # Slicing makes a copy, so overwriting the last slot with the
            # end-of-text token leaves the list held in `encoded` untouched.
            ids = ids[:context_length]
            ids[-1] = end_id
        padded[i, :len(ids)] = torch.tensor(ids)

    return padded
|
|