Zero-Shot Image Classification
Transformers
English
medical
multimodal
vision-language pre-training
chest x-ray
Instructions to use pykale/MeDSLIP with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use pykale/MeDSLIP with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("zero-shot-image-classification", model="pykale/MeDSLIP")
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png",
    candidate_labels=["animals", "humans", "landscape"],
)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("pykale/MeDSLIP", dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| # modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py | |
| from sklearn.metrics import log_loss | |
| import torch.nn as nn | |
| import torch | |
| import math | |
| import numpy as np | |
| from torch.nn.utils.rnn import pad_sequence | |
| import torch.nn.functional as F | |
| from .transformer import * | |
| import torchvision.models as models | |
| from einops import rearrange | |
| from transformers import AutoModel | |
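"""
Configuration keys read by MeDSLIP (accessed as config["<key>"]):
    N                   passed to TransformerDecoder (presumably the number of decoder layers)
    d_model             width of the decoder and of the visual projection heads
    res_base_model      ResNet backbone name: "resnet18" or "resnet50"
    text_encoder        Hugging Face model name/path used to encode the pathology book
    H                   passed to TransformerDecoderLayer (presumably the number of attention heads)
    dropout             dropout rate applied to the decoder outputs
    attribute_set_size  output size of each attribute classifier
    num_queries         listed historically; not read by the MeDSLIP class
"""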
class MeDSLIP(nn.Module):
    """ResNet patch features decoded against a fixed, pre-encoded pathology query book.

    At construction time every entry of ``pathology_book`` is encoded once by
    the text encoder named in ``config["text_encoder"]``. The hidden state at
    token position 0 of each entry is kept as a frozen query embedding.
    ``forward`` decodes image patch features against those queries and returns
    per-query attribute logits.

    Config keys read: ``d_model``, ``text_encoder``, ``res_base_model``, ``H``,
    ``N``, ``dropout``, ``attribute_set_size``.
    """

    def __init__(self, config, pathology_book):
        """Build the model.

        Args:
            config: mapping providing the keys listed in the class docstring.
            pathology_book: tokenizer output (mapping) with ``input_ids`` and
                ``attention_mask`` tensors, one row per pathology query.

        Raises:
            ValueError: if ``config["res_base_model"]`` or
                ``config["text_encoder"]`` cannot be resolved.
        """
        super().__init__()
        self.d_model = config["d_model"]

        # Encode the query book once. The text encoder is a local and is
        # discarded afterwards; only the resulting embeddings are kept.
        with torch.no_grad():
            bert_model = self._get_bert_basemodel(
                config["text_encoder"], freeze_layers=None
            ).to(pathology_book["input_ids"].device)
            book_output = bert_model(
                input_ids=pathology_book["input_ids"],
                attention_mask=pathology_book["attention_mask"],
            )
            # Keep only the token-position-0 hidden state: one row per entry.
            # Plain tensor attribute (not a parameter/buffer); forward() aligns
            # its device/dtype with the projection layer.
            self.pathology_book = book_output.last_hidden_state[:, 0, :]

        # Text-side projections. NOTE(review): 768 presumably matches the text
        # encoder's hidden size and 256 the decoder width (config["d_model"]);
        # both are hard-coded here -- confirm before changing either.
        self.pathology_embedding_layer = nn.Linear(768, 256)
        self.cl_fc_pathology = nn.Linear(256, 768)
        self.cl_fc_anatomy = nn.Linear(256, 768)

        # Visual backbone. Both ResNets are instantiated (randomly initialised)
        # even though only the configured one is used below.
        self.resnet_dict = {
            "resnet18": models.resnet18(pretrained=False),
            "resnet50": models.resnet50(pretrained=False),
        }
        resnet = self._get_res_basemodel(config["res_base_model"])
        # The trunk keeps all but the last three children. For torchvision
        # ResNets this presumably ends at layer3, whose channel count is half
        # of fc.in_features.
        num_ftrs = resnet.fc.in_features // 2
        self.res_features = nn.Sequential(*list(resnet.children())[:-3])
        self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2_anatomy = nn.Linear(num_ftrs, self.d_model)
        self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2_pathology = nn.Linear(num_ftrs, self.d_model)
        self.mask_generator = nn.Linear(num_ftrs, num_ftrs)

        # Query decoders (anatomy and pathology branches).
        self.H = config["H"]
        decoder_layer = TransformerDecoderLayer(
            self.d_model, config["H"], 1024, 0.1, "relu", normalize_before=True
        )
        # NOTE(review): the same decoder_norm instance is passed to both
        # decoders; if TransformerDecoder stores it without copying, the two
        # branches share one LayerNorm -- confirm this is intended.
        decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder_anatomy = TransformerDecoder(
            decoder_layer, config["N"], decoder_norm, return_intermediate=False
        )
        self.decoder_pathology = TransformerDecoder(
            decoder_layer, config["N"], decoder_norm, return_intermediate=False
        )

        # Dropout on decoder outputs.
        self.dropout_feas_anatomy = nn.Dropout(config["dropout"])
        self.dropout_feas_pathology = nn.Dropout(config["dropout"])
        # Attribute classifiers: attribute_set_size logits per query.
        self.classifier_anatomy = nn.Linear(self.d_model, config["attribute_set_size"])
        self.classifier_pathology = nn.Linear(self.d_model, config["attribute_set_size"])

        # _init_weights is a staticmethod, so apply() can call it with the module only.
        self.apply(self._init_weights)

    def _get_res_basemodel(self, res_model_name):
        """Return the ResNet stored under ``res_model_name`` in ``self.resnet_dict``.

        Raises:
            ValueError: if ``res_model_name`` is not a key of ``self.resnet_dict``.
        """
        try:
            res_model = self.resnet_dict[res_model_name]
        except KeyError:
            # Previously `raise ("...")`, which itself fails with TypeError.
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
            ) from None
        print("Image feature extractor:", res_model_name)
        return res_model

    def _get_bert_basemodel(self, bert_model_name, freeze_layers):
        """Load a Hugging Face text encoder, optionally freezing some encoder layers.

        Args:
            bert_model_name: model id or path passed to ``AutoModel.from_pretrained``.
            freeze_layers: iterable of indices into ``model.encoder.layer`` whose
                parameters get ``requires_grad = False``, or ``None`` to freeze nothing.

        Raises:
            ValueError: if the model cannot be loaded (the loader's error is chained).
        """
        try:
            model = AutoModel.from_pretrained(bert_model_name)
        except (OSError, ValueError, KeyError) as err:
            raise ValueError(
                "Invalid model name. Check the config file and pass a BERT model from transformers library"
            ) from err
        print("text feature extractor:", bert_model_name)
        if freeze_layers is not None:
            for layer_idx in freeze_layers:
                for param in model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = False
        return model

    def image_encoder(self, xis):
        """Split ResNet patch features into pathology and anatomy embeddings.

        Each patch feature ``x`` is gated by the output of the linear
        ``mask_generator``. That output is not squashed to [0, 1]. ``mask * x``
        feeds the pathology branch and ``(1 - mask) * x`` the anatomy branch.
        Each branch is Linear -> ReLU -> Linear, projecting to ``d_model``.

        Args:
            xis: image batch; assumed (B, C, H, W) -- TODO confirm channel count.

        Returns:
            ``(pathology, anatomy)`` embeddings, each (B, n_patches, d_model).
        """
        # Example shapes for a batch of 16 with a 1024-channel trunk:
        #   trunk output [16, 1024, 14, 14] -> per patch [16, 196, 1024]
        #   -> flattened [3136, 1024] -> projected [16, 196, 256]
        batch_size = xis.shape[0]
        res_fea = self.res_features(xis)  # (B, C, n1, n2)
        res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
        x = rearrange(res_fea, "b n d -> (b n) d")
        masks = self.mask_generator(x)
        x_pathology = masks * x
        x_anatomy = (1 - masks) * x
        x_pathology = self.res_l1_pathology(x_pathology)
        x_anatomy = self.res_l1_anatomy(x_anatomy)
        x_pathology = F.relu(x_pathology)
        x_anatomy = F.relu(x_anatomy)
        x_pathology = self.res_l2_pathology(x_pathology)
        x_anatomy = self.res_l2_anatomy(x_anatomy)
        out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
        out_emb_anatomy = rearrange(x_anatomy, "(b n) d -> b n d", b=batch_size)
        return out_emb_pathology, out_emb_anatomy

    def forward(self, images):
        """Score every pathology query against each image.

        Args:
            images: image batch accepted by ``image_encoder``.

        Returns:
            Pathology classifier logits, batch-first. Presumably
            (B, num_queries, attribute_set_size) -- depends on the decoder's
            output layout.
        """
        B = images.shape[0]
        x, _ = self.image_encoder(images)  # (B, n_patches, d_model)
        features = x.transpose(0, 1)  # (n_patches, B, d_model)
        # pathology_book is a plain tensor attribute, so Module.to()/.half() do
        # not move or cast it. Align it with the projection weight to avoid
        # device/dtype mismatches.
        weight = self.pathology_embedding_layer.weight
        book = self.pathology_book.to(device=weight.device, dtype=weight.dtype)
        query_embed = self.pathology_embedding_layer(book)  # (num_queries, 256)
        query_embed = query_embed.unsqueeze(1).repeat(1, B, 1)  # (num_queries, B, 256)
        features, _ws = self.decoder_pathology(
            query_embed,
            features,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,
        )
        features = self.dropout_feas_pathology(features)
        # Swap the first two axes back to batch-first.
        return self.classifier_pathology(features).transpose(0, 1)

    @staticmethod
    def _init_weights(module):
        r"""Initialize weights like BERT: N(0.0, 0.02).

        Intended for ``nn.Module.apply``; it takes the submodule only, hence
        ``@staticmethod``. Linear, MultiheadAttention (in/out projection) and
        Embedding weights are re-drawn. Embedding padding rows are zeroed.
        Biases are left untouched.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()