| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from argparse import Namespace |
| |
|
| | import torch |
| | from torch import nn |
| | import torch.nn.functional as F |
| |
|
| | from .adaptor_registry import adaptor_registry, dict_t, state_t |
| |
|
| | from .adaptor_generic import GenericAdaptor |
| |
|
| |
|
class OpenCLIP_RADIO(GenericAdaptor):
    """Adaptor exposing the text side of an OpenCLIP model.

    Reads two keys from ``adaptor_config``:
        - ``'model'``: the OpenCLIP model name, used for both the model and
          its tokenizer.
        - ``'pretrained'``: the pretrained-weights tag passed to OpenCLIP.

    The image tower (``visual``) is removed right after loading. Only text
    encoding is offered by this class.
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        super().__init__(main_config, adaptor_config, state)

        # Deferred import: ``open_clip`` is only required when an adaptor is
        # actually constructed, not when this module is imported.
        import open_clip

        model_name = adaptor_config['model']
        pretrained_tag = adaptor_config['pretrained']

        # ``return_transform=False`` asks OpenCLIP for just the model, without
        # the preprocessing transform.
        self.oc_model = open_clip.create_model_from_pretrained(
            model_name=model_name,
            pretrained=pretrained_tag,
            return_transform=False,
        )

        # Discard the image tower; nothing in this adaptor uses it.
        # Presumably done to release its parameters -- NOTE(review): confirm
        # no caller relies on ``oc_model.visual``.
        self.oc_model.visual = None

        # Tokenizer belonging to the same OpenCLIP model name.
        self.tokenizer = open_clip.get_tokenizer(model_name=model_name)

    def encode_text(self, text, normalize: bool = False):
        """Run ``text`` through the OpenCLIP text encoder.

        Args:
            text: Forwarded unchanged to ``oc_model.encode_text``; presumably
                token ids produced by ``self.tokenizer`` -- TODO confirm.
            normalize: Forwarded as the ``normalize`` flag of
                ``oc_model.encode_text``.

        Returns:
            The value returned by ``oc_model.encode_text``.
        """
        model = self.oc_model
        return model.encode_text(text, normalize=normalize)
| |
|
| |
|
@adaptor_registry.register_adaptor("open_clip")
def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Construct an ``OpenCLIP_RADIO`` adaptor.

    Registered in ``adaptor_registry`` under the name ``"open_clip"``. All
    three arguments are passed through unchanged to the adaptor constructor.
    """
    adaptor = OpenCLIP_RADIO(
        main_config=main_config,
        adaptor_config=adaptor_config,
        state=state,
    )
    return adaptor
| |
|