| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Dict, List, Literal, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torchvision.transforms |
| from einops import rearrange |
|
|
| from .siglip_vit import create_siglip_vit |
|
|
|
|
class CLIPVisionTower(nn.Module):
    """Wrap a vision backbone selected by name and expose its image features.

    The backbone is chosen from the prefix of ``model_name``:

    * ``"siglip..."`` -> ``create_siglip_vit``
    * ``"sam..."``    -> ``create_sam_vit`` (must be available in this module)
    * anything else   -> ``transformers.CLIPVisionModel.from_pretrained``

    Inputs are optionally normalized with ``torchvision.transforms.Normalize``
    before being passed to the backbone.
    """

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[list] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """Build the backbone and the optional input normalization.

        Args:
            model_name: Backbone identifier; its prefix picks the builder.
            image_size: Forwarded to the backbone builder as ``image_size``.
            select_feature: One of ``"patch"``, ``"cls_patch"`` or ``"same"``;
                see :meth:`feature_select`. Overridden to ``"same"`` for
                siglip backbones.
            select_layer: Index into ``hidden_states`` used when the backbone
                returns a non-tensor output; also forwarded to the builder.
            select_layers: Stored on the instance; not otherwise used by this
                class.
            ckpt_path: Forwarded to the backbone builder as ``ckpt_path``.
            pixel_mean: Per-channel mean for input normalization.
            pixel_std: Per-channel std for input normalization.
            **kwargs: Extra entries merged into the builder parameters
                (they can override the entries above).
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        # Normalization is enabled only when BOTH statistics are supplied;
        # providing just one of them leaves the inputs untouched.
        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        """Instantiate the backbone that matches ``self.model_name``.

        Args:
            vision_tower_params: Keyword arguments passed to the selected
                builder.

        Returns:
            A ``(vision_tower, forward_kwargs)`` tuple, where
            ``forward_kwargs`` are extra keyword arguments to pass to the
            backbone on every forward call.

        Raises:
            NotImplementedError: If a ``"sam"`` backbone is requested but
                ``create_sam_vit`` is not available in this module.
        """
        if self.model_name.startswith("siglip"):
            # Side effect: siglip features are always returned unchanged.
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = {}

        elif self.model_name.startswith("sam"):
            # ``create_sam_vit`` is not imported alongside ``create_siglip_vit``,
            # so a direct call would die with an opaque NameError. Resolve it
            # from the module namespace (works if it is bound elsewhere in the
            # file) and fail with an actionable message otherwise.
            sam_factory = globals().get("create_sam_vit")
            if sam_factory is None:
                raise NotImplementedError(
                    f"Cannot build vision tower {self.model_name!r}: "
                    "`create_sam_vit` is not imported in this module."
                )
            vision_tower = sam_factory(**vision_tower_params)
            forward_kwargs = {}

        else:
            from transformers import CLIPVisionModel

            # NOTE(review): the params are passed purely as keywords
            # (``model_name=...``); confirm this matches the signature of
            # ``from_pretrained`` in the transformers version in use.
            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the feature tensor to return from the backbone output.

        Args:
            image_forward_outs: Either a ``torch.Tensor`` (used as is) or an
                object exposing ``hidden_states``, from which the entry at
                ``self.select_layer`` is taken.

        Returns:
            The selected features. With ``select_feature == "patch"`` the
            first element along dim 1 is dropped; with ``"cls_patch"`` or
            ``"same"`` the features are returned unchanged.

        Raises:
            ValueError: If ``self.select_feature`` is not a supported mode.
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # The backbone already returned the features directly.
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # Drop the leading token along the sequence dimension.
            return image_features[:, 1:]
        if self.select_feature in ("cls_patch", "same"):
            return image_features
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
|
|