| |
| """ |
| Tiny VQAScore Model Wrapper |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from PIL import Image |
| import numpy as np |
|
|
class TinyVQAScore:
    """Tiny, randomly initialised stand-in for a VQAScore model.

    The wrapped network is built with freshly initialised (untrained)
    weights, so the returned scores carry no semantic meaning. The class
    is useful for exercising code that only needs the ``score`` interface.
    """

    VOCAB_SIZE = 32128  # number of rows in the text embedding table
    NUM_TOKENS = 10  # fixed count of token ids fed to the text encoder
    IMAGE_SIZE = (224, 224)  # PIL inputs are resized to this size

    def __init__(self, model="tiny-random", device="cpu"):
        """Build the tiny network and put it in eval mode on ``device``.

        Args:
            model: Label for the model variant. It does not influence the
                architecture; it is only stored as ``model_name``.
            device: Anything accepted by ``torch.device``.
        """
        self.model_name = model
        self.device = torch.device(device)
        self.model = self._create_tiny_model()
        self.model.to(self.device)
        self.model.eval()

    def _create_tiny_model(self):
        """Return a new, randomly initialised ``TinyCLIPT5`` module."""
        vocab_size = self.VOCAB_SIZE

        class TinyCLIPT5(nn.Module):
            """Vision encoder + text encoder + scalar projection head."""

            def __init__(self):
                super().__init__()
                # 16x16 patch convolution, global average pool, 256-d output.
                self.vision_encoder = nn.Sequential(
                    nn.Conv2d(3, 64, kernel_size=16, stride=16),
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(64, 256),
                )
                # Token embedding followed by one transformer encoder layer.
                self.text_encoder = nn.Sequential(
                    nn.Embedding(vocab_size, 256),
                    nn.LayerNorm(256),
                    nn.TransformerEncoderLayer(
                        d_model=256,
                        nhead=8,
                        dim_feedforward=512,
                        dropout=0.1,
                        batch_first=True,
                    ),
                )
                # MLP mapping the fused 256-d feature to a single logit.
                self.multimodal_projector = nn.Sequential(
                    nn.Linear(256, 128), nn.GELU(),
                    nn.Linear(128, 64), nn.GELU(),
                    nn.Linear(64, 1),
                )
                self._init_weights()

            def _init_weights(self):
                """Re-initialise Linear/Conv2d/Embedding weights to small values."""
                for module in self.modules():
                    if isinstance(module, (nn.Linear, nn.Conv2d)):
                        nn.init.xavier_uniform_(module.weight, gain=0.1)
                        if module.bias is not None:
                            nn.init.uniform_(module.bias, -0.1, 0.1)
                    elif isinstance(module, nn.Embedding):
                        nn.init.uniform_(module.weight, -0.1, 0.1)

            def forward(self, pixel_values, input_ids):
                """Return one logit per batch row for the image/token pair."""
                vision_features = self.vision_encoder(pixel_values)
                text_features = self.text_encoder(input_ids)
                # Average over the token axis to get one vector per sequence.
                text_features = text_features.mean(dim=1)
                combined_features = vision_features + text_features
                score = self.multimodal_projector(combined_features)
                return score.squeeze(-1)

        return TinyCLIPT5()

    def _prepare_image(self, image):
        """Convert ``image`` into the pixel tensor handed to the model.

        Tensors are checked first and used as given, except that a 3-D
        (C, H, W) tensor gets a leading batch dimension. PIL images are
        converted to RGB, resized to ``IMAGE_SIZE`` and scaled to [0, 1].
        Any other object is returned unchanged.
        """
        if isinstance(image, torch.Tensor):
            # Without a batch axis, Flatten/Linear in the vision encoder
            # see the wrong trailing dimension and raise.
            return image.unsqueeze(0) if image.dim() == 3 else image
        if isinstance(image, Image.Image):
            # Force three channels: grayscale arrays are 2-D and RGBA has
            # four channels, neither of which fits the 3-channel conv.
            rgb = image.convert("RGB").resize(self.IMAGE_SIZE)
            pixels = torch.from_numpy(np.array(rgb)).float()
            return pixels.permute(2, 0, 1).unsqueeze(0) / 255.0
        return image

    def _encode_question(self, question):
        """Map ``question`` to a deterministic (1, NUM_TOKENS) id tensor.

        This is a stand-in for a real tokenizer: each whitespace-separated
        word is hashed with CRC32 into ``[1, VOCAB_SIZE)``; id 0 pads short
        questions and extra words are dropped.
        """
        import zlib  # local import keeps this block self-contained

        words = str(question).split()[: self.NUM_TOKENS]
        # CRC32 is stable across processes, unlike the salted builtin hash().
        ids = [
            zlib.crc32(word.encode("utf-8")) % (self.VOCAB_SIZE - 1) + 1
            for word in words
        ]
        ids.extend([0] * (self.NUM_TOKENS - len(ids)))
        return torch.tensor([ids], dtype=torch.long)

    def score(self, image, question):
        """Return a score in (0, 1) for an image/question pair.

        The same (image, question) pair always produces the same score for
        a given instance, because the token ids are derived from the
        question text rather than sampled randomly.

        Args:
            image: A PIL image, a (C, H, W) tensor, or a (1, C, H, W)
                tensor with three channels.
            question: Text of the question; converted with ``str()``.

        Returns:
            The sigmoid of the model's logit as a Python float. Only a
            single image is supported because the result is read with
            ``.item()``.
        """
        pixel_values = self._prepare_image(image)
        input_ids = self._encode_question(question).to(self.device)

        with torch.no_grad():
            logit = self.model(pixel_values.to(self.device), input_ids)

        return torch.sigmoid(logit).item()
|
|
if __name__ == "__main__":
    # Smoke test: score a solid red 224x224 image on CPU and print the result.
    scorer = TinyVQAScore(device="cpu")
    red_square = Image.new('RGB', (224, 224), color='red')
    score = scorer.score(red_square, "What color is this image?")
    print(f"Test score: {score}")
|
|