| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig |
| |
|
| | |
class SegVol(nn.Module):
    """Promptable volumetric segmentation model assembled from injected parts.

    The image encoder, prompt encoder and mask decoder are supplied by the
    caller; a ``TextEncoder`` is built here from ``clip_ckpt``.

    NOTE(review): the reviewed source had lost its indentation, so statement
    nesting was reconstructed from syntax. In particular the extent of the
    ``torch.no_grad()`` block in ``forward_decoder`` should be confirmed
    against the upstream file.
    """

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
        clip_ckpt,
        roi_size,
        patch_size,
        test_mode=False,
    ):
        """Store the sub-modules and derive the feature-grid shape.

        Args:
            image_encoder: Module called as ``image_encoder(image)``; must
                return a pair whose first item is the embedding tensor.
            mask_decoder: Module producing low-resolution masks.
            prompt_encoder: Module encoding points / boxes / text prompts;
                must also expose ``get_dense_pe()``.
            clip_ckpt: Forwarded unchanged to ``TextEncoder``.
            roi_size: Per-axis input size (array-like).
            patch_size: Per-axis patch size (array-like).
            test_mode: When true, ``forward`` runs the decoder and returns
                logits.
        """
        super().__init__()
        # Plain (non-module) attributes.
        roi = np.array(roi_size)
        patch = np.array(patch_size)
        # True division: this is a float array, converted with int() on use.
        self.feat_shape = roi / patch
        self.test_mode = test_mode

        # Sub-modules, registered in the same order as before so that
        # parameter / state_dict ordering is unchanged.
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder(clip_ckpt)

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        """Encode ``image`` and, in test mode, decode it into mask logits.

        Args:
            image: Tensor with at least five dimensions; axes 2, 3 and 4 are
                used as the output size of the returned logits.
            text: Optional text prompt(s) handed to the text encoder.
            boxes: Optional box prompts.
            points: Optional point prompts.
            **kwargs: Accepted and ignored.

        Returns:
            The result of ``forward_decoder`` when ``self.test_mode`` is
            true; otherwise ``None`` (no other branch exists in this code).
        """
        batch = image.shape[0]
        spatial_size = tuple(image.shape[axis] for axis in (2, 3, 4))

        encoded, _ = self.image_encoder(image)
        grid = [int(self.feat_shape[axis]) for axis in range(3)]
        # Swap axes 1 and 2, then unfold the last axis into the 3-D grid.
        # presumably the encoder emits (batch, tokens, channels) - confirm
        image_embedding = encoded.transpose(1, 2).view(batch, -1, *grid)

        if not self.test_mode:
            return None
        return self.forward_decoder(image_embedding, spatial_size, text, boxes, points)

    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        """Turn an image embedding plus prompts into full-size mask logits.

        Args:
            image_embedding: Embedding passed to the mask decoder.
            img_shape: Target size for the trilinear upsampling.
            text: Optional text prompt(s); encoded only when not ``None``.
            boxes: Optional box prompts; a 2-D input gets a singleton axis
                inserted at position 1.
            points: Optional point prompts, forwarded as-is.

        Returns:
            The decoder's low-resolution masks upsampled to ``img_shape``.
        """
        # Prompt preparation runs without gradient tracking.
        with torch.no_grad():
            if boxes is not None and len(boxes.shape) == 2:
                boxes = boxes[:, None, :]
            text_embedding = self.text_encoder(text) if text is not None else None

        sparse_prompt, dense_prompt = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )
        positional = self.prompt_encoder.get_dense_pe()

        coarse_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding=text_embedding,
            image_pe=positional,
            sparse_prompt_embeddings=sparse_prompt,
            dense_prompt_embeddings=dense_prompt,
            multimask_output=False,
        )
        return F.interpolate(coarse_masks, size=img_shape, mode='trilinear', align_corners=False)
| |
|
class TextEncoder(nn.Module):
    """Encode organ names into prompt embeddings with a frozen CLIP text model.

    Each name is wrapped in a fixed sentence template, tokenized, passed
    through ``CLIPTextModel`` and its pooled output is projected by a linear
    layer from 512 to 768 features.
    """

    def __init__(self, clip_ckpt):
        """Build the text model, tokenizer and projection layer.

        Args:
            clip_ckpt: Name or path given to ``AutoTokenizer.from_pretrained``.
                Only the tokenizer is read from it here; the CLIP text model
                is constructed from a default ``CLIPTextConfig`` and its
                weights are not loaded in this constructor.
        """
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained(clip_ckpt)
        # presumably 512 matches the default CLIPTextConfig hidden size - confirm
        self.dim_align = nn.Linear(512, 768)

        # Freeze the CLIP text model; only ``dim_align`` stays trainable.
        for param in self.clip_text_model.parameters():
            param.requires_grad = False

    def organ2tokens(self, organ_names):
        """Tokenize one templated sentence per entry of ``organ_names``.

        Args:
            organ_names: Iterable of names to insert into the template.

        Returns:
            The tokenizer output for the padded batch, as PyTorch tensors
            (left on whatever device the tokenizer produced them on).
        """
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        return tokens

    def forward(self, text):
        """Return one embedding per prompt, or ``None`` when ``text`` is ``None``.

        Args:
            text: A single string, an iterable of strings, or ``None``.

        Returns:
            The projected pooled CLIP output (768 features per prompt), or
            ``None`` if no text was given.
        """
        if text is None:
            return None
        if isinstance(text, str):
            # A bare string would otherwise be iterated character by character.
            text = [text]

        # The tokenizer builds its tensors on the CPU. Move them to the
        # device this module's weights live on, so the encoder still works
        # after ``.to(device)`` / ``.cuda()`` has been called on the model.
        device = self.dim_align.weight.device
        tokens = {
            name: tensor.to(device)
            for name, tensor in self.organ2tokens(text).items()
        }

        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = self.dim_align(clip_outputs.pooler_output)
        return text_embedding
| |
|