| import math |
| import pickle |
| import sys |
|
|
| import torch |
| from torch import nn, optim |
| from torch.utils.data import DataLoader |
| import torch.nn.functional as F |
|
|
| import lightning as L |
| from lightning.pytorch.tuner import Tuner |
| from lightning.pytorch.callbacks import LearningRateMonitor |
| import wandb |
| from pytorch_lightning.loggers import WandbLogger |
| from transformers import CLIPTokenizer, CLIPTextModelWithProjection |
|
|
|
|
|
|
class SoftAttention(L.LightningModule):
    """Single-head scaled dot-product attention between query items and key items.

    Each item is a (text embedding, image embedding) pair.  All projection
    layers are hard-coded for 512-dimensional inputs (presumably CLIP
    embeddings -- TODO confirm against the dataset builder).  When text is
    present on a side, text and image are each reduced to 256 dims, projected,
    and concatenated back to 512 dims; when text is absent, the image alone is
    projected 512 -> 512.  ``forward`` returns the softmax-normalised attention
    map together with the raw scaled logits.

    Training-only input ablations (see ``_apply_input_ablations``):
    ``fixed_text`` replaces every text embedding with a constant vector;
    ``random_text`` / ``random_images`` / ``random_everything`` replace the
    corresponding inputs with Gaussian noise.
    """

    # (query_text, query_image, key_text, key_image) presence patterns that
    # ``weight_pass`` accepts: images are mandatory, and a query with text
    # requires a key with text.
    _VALID_INPUT_PATTERNS = (
        (True, True, True, True),
        (False, True, False, True),
        (False, True, True, True),
    )

    def __init__(self, learning_rate=0.001, batch_size=10, unfreeze=0, random_text=False,
                 random_everything=False, fixed_text=False, random_images=False):
        """Build the projection layers and the per-epoch metric accumulators.

        Args:
            learning_rate: Adam learning rate used by ``configure_optimizers``.
            batch_size: Stored on the instance and in hparams; not read
                elsewhere in this class.
            unfreeze: When non-zero, ``image_reduction`` is frozen at epoch 0
                and unfrozen when this epoch is reached; 0 disables freezing.
            random_text: Replace query/key text with Gaussian noise in training.
            random_everything: Replace all four inputs with Gaussian noise.
            fixed_text: Replace query/key text with a constant vector.
            random_images: Replace query/key images with Gaussian noise.
        """
        super().__init__()
        self.my_optimizer = None
        self.my_scheduler = None  # no scheduler is configured; kept for compatibility
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.frozen = False
        self.unfreeze_epoch = unfreeze
        self.loss_method = torch.nn.CrossEntropyLoss()

        # Running sums behind the ``*_epoch`` metrics logged at epoch end.
        self.train_sum_precision = 0
        self.train_sum_accuracy = 0
        self.train_sum_recall = 0
        self.train_sum_runs = 0
        self.val_sum_precision = 0
        self.val_sum_accuracy = 0
        self.val_sum_recall = 0
        self.val_sum_runs = 0

        # NOTE: layer creation order is kept exactly as before so parameter
        # initialisation (RNG consumption), optimizer parameter order and
        # state_dict keys are unchanged.
        # 512 -> 256 reductions shared by the query and key paths.
        self.text_reduction = torch.nn.Linear(512, 256)
        self.image_reduction = torch.nn.Linear(512, 256)

        self.W_query_text_half_dim = torch.nn.Linear(256, 256)
        self.W_query_image_half_dim = torch.nn.Linear(256, 256)
        # The *_text_full_dim layers are not used by forward(); kept so that
        # existing checkpoints still load.
        self.W_query_text_full_dim = torch.nn.Linear(512, 512)
        self.W_query_image_full_dim = torch.nn.Linear(512, 512)

        self.W_key_text_half_dim = torch.nn.Linear(256, 256)
        self.W_key_image_half_dim = torch.nn.Linear(256, 256)
        self.W_key_image_full_dim = torch.nn.Linear(512, 512)
        self.W_key_text_full_dim = torch.nn.Linear(512, 512)

        # Constant 512-d text embedding substituted when ``fixed_text`` is set.
        # Registered as a non-persistent buffer so it follows the module's
        # device/dtype without adding a new state_dict entry.
        fixed_text_values = torch.tensor([
            2.2875e-01, 2.3762e-02, 1.3448e-01, 6.5997e-02, 2.5605e-01,
            -1.6183e-01, 7.1169e-03, -1.6895e+00, 1.8110e-01, 1.7249e-01,
            7.0582e-02, -6.3566e-02, -1.5862e-01, -2.3586e-01, 6.9382e-02,
            9.4649e-02, 6.3127e-01, -4.1287e-02, -4.9883e-02, -2.1821e-01,
            5.8677e-01, -2.5353e-01, 1.4792e-01, 2.2195e-02, -6.8436e-02,
            -1.5512e-01, -9.8894e-02, 6.3377e-02, -2.3078e-01, 9.3588e-02,
            5.2875e-02, -5.1388e-01, -7.0461e-02, 2.4253e-02, -7.8069e-02,
            7.6921e-02, -1.1610e-01, -1.3345e-01, 7.8038e-03, -2.0226e-01,
            1.1381e-01, -9.6335e-02, -2.2195e-02, -6.5028e-02, 1.4025e-01,
            2.6969e-01, -1.0758e-01, 3.6736e-02, 3.2893e-01, -1.9067e-01,
            4.9070e-02, 8.0207e-02, 7.2942e-02, 7.7496e-03, 2.0883e-01,
            1.7339e-01, 1.0072e-01, -1.7874e-01, -4.6898e-02, -6.2682e-02,
            5.9596e-02, 5.2925e-02, 2.4633e-01, -7.2811e-02, -1.4157e-01,
            8.8013e-03, -4.6815e-02, -7.4260e-02, 8.6530e-03, -1.8174e-01,
            1.6101e-01, -4.8832e-02, -5.8030e-02, -3.2518e-02, -6.2896e-02,
            -2.3472e-01, -8.0996e-02, 1.1261e-01, -2.1039e-01, -2.3837e-01,
            -2.6827e-02, -2.3075e-01, -2.2087e-02, 5.4009e-01, 3.7671e-02,
            3.3140e-01, -4.2569e-02, -1.6946e-01, 1.7165e-01, 3.0887e-01,
            4.9847e-02, 1.2438e-02, -2.0701e+00, 2.7104e-01, 1.9001e-01,
            3.1907e-01, -9.1116e-02, -8.3141e-02, 4.5765e-03, -2.5675e-01,
            -2.2119e-02, 3.4949e-02, 2.8192e-01, 7.9688e-02, -2.1810e-01,
            8.1565e-02, 3.3208e-01, -9.1857e-02, -2.1145e-01, -1.6843e-01,
            6.7942e-02, 5.1067e-01, -1.6835e-01, 2.2090e-02, 1.8061e-02,
            -2.1313e-01, 2.6867e-02, -2.2734e-01, 8.4164e-02, -4.7868e-02,
            2.0980e-02, -2.1424e-01, -2.2919e-02, 1.7554e-01, 5.2253e-02,
            -2.2049e-01, 6.9408e-02, 7.0811e-02, -1.1892e-02, -4.7958e-02,
            7.9476e-02, 1.8851e-01, 2.2516e-02, 8.6119e+00, -7.8583e-02,
            1.0218e-01, 1.6675e-01, -4.0961e-01, 4.5291e-02, 7.9783e-02,
            -1.1764e-01, -2.3162e-01, -2.7717e-02, 1.2963e-01, -3.0165e-01,
            -2.1588e-02, -1.2324e-01, 1.9732e-02, -1.9312e-01, -7.1229e-02,
            2.5102e-01, -4.1674e-01, -1.5610e-01, -6.1321e-03, -4.5332e-02,
            6.1500e-02, -1.5942e-01, 3.5142e-01, -2.1119e-01, 4.5057e-02,
            -5.6277e-02, -3.4298e-01, -1.6499e-01, -2.9384e-02, -2.7163e-01,
            6.5339e-03, 2.7674e-02, -1.1302e-01, -2.6373e-02, -1.4370e-01,
            2.1936e-01, 1.3103e-01, 2.5538e-01, 1.9502e-01, -1.5278e-01,
            1.4978e-01, -2.5552e-01, 2.2397e-01, -1.0369e-01, -1.0491e-01,
            5.1112e-01, 2.4879e-01, 7.0940e-02, 1.7351e-01, -3.6831e-02,
            1.5027e-01, -1.9452e-01, 2.0322e-01, 8.5931e-02, -2.8588e-03,
            3.1146e-02, -3.3307e-01, 1.1595e-01, 1.9435e-01, -3.4536e-02,
            2.5245e-01, 4.5388e-02, 2.1197e-02, 4.2232e-02, 4.2436e-02,
            4.9622e-02, -2.0907e-01, 1.2264e-01, -7.3529e-02, -2.1788e-01,
            -1.2429e-01, -8.1422e-02, 1.6572e-01, -6.0989e-02, 8.0322e-02,
            3.3477e-01, -7.2207e-02, -8.8658e-02, -2.4944e-01, 9.9211e-02,
            8.6244e-02, 8.8807e-02, -1.9676e-01, -4.5365e-03, -3.7754e-01,
            -1.7204e-01, -1.3001e-01, 6.4961e-02, -5.8192e-03, 2.4670e-01,
            -8.3591e-02, -3.0810e-01, -3.4549e-02, -1.4452e-01, -5.5416e-02,
            1.0527e-02, 3.1159e-01, -1.3857e-01, -2.2676e-01, 1.4768e-01,
            3.2650e-01, 2.3971e-01, 6.8196e-02, -2.6235e-02, -2.9741e-01,
            4.7721e-02, -1.2859e-02, 2.0340e-01, 1.7823e-02, -1.1337e-01,
            4.4077e-02, -1.3949e-01, 2.9229e-01, 1.7425e-01, -5.0722e-03,
            -6.3722e-02, 1.0181e-01, 2.3344e-02, 2.2200e-01, 3.5022e-02,
            1.5361e-01, -1.0702e-03, 2.9319e-02, 1.8938e-01, -7.2263e-02,
            2.2192e-02, 9.5394e-02, -4.4459e-03, 7.6698e-02, -1.7830e-01,
            1.0213e-01, -8.8493e-02, -1.6439e-01, -1.1085e-01, 1.2938e-01,
            2.3929e-01, -4.9047e-02, -1.2814e-01, -2.1075e-01, 2.4423e-01,
            -4.4565e-02, -5.1225e-02, -4.0214e-02, -1.4033e-01, 6.3284e-02,
            4.7094e-01, -2.6821e-02, 2.1138e-02, 1.1590e-01, -2.0023e-02,
            1.7200e-01, 3.8215e-01, -2.4871e-01, -1.5359e-01, 2.4691e-01,
            1.4904e-01, -1.0636e-01, 2.4185e-01, 1.7119e-03, 1.4618e-01,
            -1.6813e-01, -4.4372e-01, -1.7475e-01, -6.9891e-02, -4.5553e-02,
            9.3102e-02, 1.7686e-02, -1.1781e-01, 6.9423e-02, 1.0211e-02,
            3.2742e-01, 7.5272e-02, 8.5080e-02, -1.7731e-01, 1.4030e-01,
            2.7764e-01, -6.5041e-02, 8.5968e+00, 2.5900e-01, -2.0825e-01,
            9.6241e-02, -1.5257e-01, -3.4269e-01, -1.1251e-01, 3.0549e-01,
            3.1628e-01, 6.1856e-01, 1.5791e-03, 6.5656e-02, 1.8862e-02,
            -7.1927e-02, 1.3239e-01, -1.1126e-01, 1.1135e-02, -3.2411e+00,
            -4.7349e-02, 1.4775e-01, -9.7712e-02, 4.5727e-02, -1.3868e-01,
            2.1260e-01, 1.5465e-01, 1.1308e-01, -8.0110e-02, -1.3123e-01,
            1.8527e-01, -8.6424e-02, -1.9778e-01, -1.3295e-01, -1.5880e-01,
            2.0800e-01, -3.6513e-02, 2.6472e-02, 2.7275e-01, 1.8995e-01,
            -7.7340e-02, 1.2059e-02, 3.5163e-02, 1.5442e-02, 5.1417e-02,
            5.0993e-01, 1.2994e-01, 2.3873e-01, -7.2816e-02, 1.5850e-01,
            -2.0404e-01, -2.2941e-01, 2.3660e-01, 2.0418e-01, 6.7775e-02,
            -3.9195e-01, 3.6655e-01, 1.6498e-01, 6.4065e-02, 4.9579e-02,
            2.8265e-01, -5.9919e-03, 4.0163e-02, 8.9072e-02, 1.5125e-01,
            9.0711e-02, -1.2608e-01, -1.0413e-01, -2.1931e-01, 5.0183e-02,
            -3.4841e-02, -8.1449e-02, -1.1225e-01, -4.5787e-02, -7.8871e-02,
            3.8858e-02, 9.2660e-02, 1.5991e-01, -6.7528e-02, -6.3166e-02,
            -4.7824e-03, -1.3528e-01, 1.4845e-01, 2.0460e-01, -9.3238e-02,
            1.4902e-03, 1.1896e-01, -3.1337e-01, 2.1637e-02, 1.4990e-01,
            -2.1179e-03, -8.1374e-02, -1.0241e-01, -8.0754e-02, -1.4449e-01,
            -1.3549e-01, -7.5588e-02, -8.0083e-02, -1.4114e-01, 2.9467e-03,
            3.5340e-01, -4.3351e-02, 9.6934e-02, 1.3625e-01, 1.3339e-01,
            -1.2059e-02, -1.4325e-01, -2.1202e-01, 3.8758e-02, 2.5965e-01,
            -7.8454e-02, 1.5983e-01, 1.0115e-02, 2.2192e-01, -1.4043e-01,
            6.7966e-02, -1.4672e-01, -1.8846e-01, 1.9488e-01, 1.2942e-01,
            -1.3165e-02, -1.6099e-01, -9.6146e-02, 1.3439e-01, -5.0560e-02,
            8.2779e-02, -2.4827e-01, -7.8047e-02, -3.1163e-01, -1.7481e-01,
            2.1450e-01, -7.6112e-02, -1.9967e-02, 5.7099e-02, 7.7664e-02,
            -7.9647e-02, 3.3941e-02, 2.9551e-02, 1.4231e-01, 2.3480e-02,
            1.5209e-01, -2.0011e-01, 1.1153e-01, 1.2694e-01, 8.7853e-02,
            2.6997e-01, 1.3525e-01, 1.9541e-01, 3.4429e-03, -9.6446e-02,
            7.6708e-02, -3.0698e-02, -1.8507e-01, 2.5645e-01, 2.8182e-01,
            -1.2282e-01, -1.1017e-01, 2.2249e-01, 2.1966e-01, 3.5795e-01,
            1.6279e-01, 1.7276e-01, 2.1410e-01, -3.2499e-01, 5.0327e-02,
            7.9813e-02, -1.5915e-01, -3.6175e-02, 1.4376e-01, 2.9565e-01,
            6.9097e-02, -8.0661e-01, 4.9966e-02, 6.2506e-02, 1.8852e-02,
            -8.6921e-02, 6.0899e-02, 2.2442e-01, -1.4272e-01, -4.0656e-04,
            -1.2531e-01, 1.5240e-01, -6.8841e-02, 4.2114e-01, -4.4379e-02,
            -3.5105e-02, 1.4931e-01, -8.3358e-02, -1.0498e-01, 1.4575e-01,
            -1.6491e-01, 4.7820e-02, 2.5958e-01, 1.1974e-01, 1.8271e-01,
            1.7439e-02, -1.5855e-01, -9.0135e-02, -2.6199e-01, -2.5709e-01,
            6.3203e-03, 7.5823e-02,
        ])
        self.register_buffer("fixed_text", fixed_text_values, persistent=False)

        self.random_text_flag = random_text
        self.random_everything_flag = random_everything
        self.fixed_text_flag = fixed_text
        self.random_image_flag = random_images

        # Lookup tables of the layers used by each path.  Not read by this
        # class; kept for external callers that may inspect them.
        self.W_query = {
            "multimodal": [self.text_reduction, self.image_reduction, self.W_query_text_half_dim,
                           self.W_query_image_half_dim],
            "image": [self.W_query_image_full_dim],
        }
        self.W_key = {
            "multimodal": [self.text_reduction, self.image_reduction, self.W_key_text_half_dim,
                           self.W_key_image_half_dim],
            "image": [self.W_key_image_full_dim],
        }

    # ------------------------------------------------------------------ model

    def weight_pass(self, query_text, query_image, key_text, key_image):
        """Project raw embeddings into query and key vectors.

        Both images are mandatory; text is optional subject to
        ``_VALID_INPUT_PATTERNS`` (a query with text requires a key with text).

        Returns:
            ``(query, key)``: for each side, the concatenated multimodal
            projection when text is given, else the full-dimension image
            projection.

        Raises:
            ValueError: if an image is missing or the text combination is
                not supported.
        """
        # Explicit identity checks: ``None in (tensor, ...)`` would fall back
        # to tensor ``__eq__``.
        if query_image is None or key_image is None:
            raise ValueError("Query and Key image cannot be None")

        presence = (query_text is not None, True, key_text is not None, True)
        if presence not in self._VALID_INPUT_PATTERNS:
            raise ValueError("Invalid input")

        query = self._queries_inference(query_text, query_image)
        key = self._keys_inference(key_text, key_image)
        return query, key

    def _project(self, text, image, text_half, image_half, image_full, side):
        """Shared projection for one side (query or key).

        Without text the image goes through ``image_full``; with text, both
        modalities are reduced to 256 dims, projected, and concatenated on
        the last axis.
        """
        if text is None:
            return image_full(image)
        if image is None:
            raise ValueError(f"{side} image cannot be None")
        text_part = text_half(self.text_reduction(text))
        image_part = image_half(self.image_reduction(image))
        return torch.cat((text_part, image_part), dim=-1)

    def _queries_inference(self, query_text, query_image):
        """Project the query side (see ``_project``)."""
        return self._project(query_text, query_image, self.W_query_text_half_dim,
                             self.W_query_image_half_dim, self.W_query_image_full_dim, "Query")

    def _keys_inference(self, key_text, key_image):
        """Project the key side (see ``_project``)."""
        return self._project(key_text, key_image, self.W_key_text_half_dim,
                             self.W_key_image_half_dim, self.W_key_image_full_dim, "Key")

    def forward(self, query_text, query_image, key_text, key_image):
        """Compute attention of queries over keys.

        Inputs are moved to ``self.device``; ``None`` text is passed through
        so the image-only combinations accepted by ``weight_pass`` work.  The
        last two axes are treated as (items, features); presumably inputs are
        ``(batch, n_items, 512)`` -- TODO confirm against the dataset.

        Returns:
            ``(softmax, logits)`` where ``logits = Q @ K^T / sqrt(d_k)`` with
            all singleton axes squeezed out, and ``softmax`` normalises it over
            axis 0 when the squeezed logits have rank <= 2, else over axis 1
            (the query axis of a batched result).
        """
        query_text, query_image, key_text, key_image = (
            None if tensor is None else tensor.to(self.device)
            for tensor in (query_text, query_image, key_text, key_image)
        )
        query, key = self.weight_pass(query_text, query_image, key_text, key_image)

        d_k = key.size(-1)
        # Transpose the last two axes: identical to transpose(1, 2) for the
        # batched 3-D case, and also valid for unbatched 2-D inputs.
        logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        logits = logits.squeeze()

        softmax_dim = 0 if logits.dim() <= 2 else 1
        softmax = F.softmax(logits, dim=softmax_dim)
        return softmax, logits

    # ---------------------------------------------------------- step helpers

    def _update_freeze_state(self):
        """Freeze ``image_reduction`` at epoch 0 and unfreeze it at ``unfreeze_epoch``."""
        if self.current_epoch == 0 and not self.frozen and self.unfreeze_epoch != 0:
            print("Freezing....................................................")
            for param in self.image_reduction.parameters():
                param.requires_grad = False
            self.frozen = True

        if self.current_epoch == self.unfreeze_epoch and self.frozen:
            print("Unfreezing....................................................")
            for param in self.image_reduction.parameters():
                param.requires_grad = True
            self.frozen = False

    @staticmethod
    def _stack_pairs(groups):
        """Split nested ``(text, image)`` pairs into two stacked tensors.

        ``groups`` is iterated twice-deep: each group yields ``(text, image)``
        pairs.  Returns ``(texts, images)``, each stacked over pairs and then
        over groups.
        """
        texts = []
        images = []
        for group in groups:
            group_texts = []
            group_images = []
            for text, image in group:
                group_texts.append(text)
                group_images.append(image)
            texts.append(torch.stack(group_texts))
            images.append(torch.stack(group_images))
        return torch.stack(texts), torch.stack(images)

    @staticmethod
    def _randomize(tensor, label):
        """Return Gaussian noise shaped like ``tensor`` on the same device."""
        replaced = torch.randn(tensor.shape).to(tensor.device)
        if torch.equal(replaced, tensor):
            print(f"{label} are equal")
        return replaced

    def _apply_input_ablations(self, queries_text, queries_image, keys_text, keys_image):
        """Apply the enabled training-time input ablations, in a fixed order."""
        if self.fixed_text_flag:
            print("Fixed text flag")
            queries_text = self.fixed_text.expand(*queries_text.shape).to(queries_text.device)
            keys_text = self.fixed_text.expand(*keys_text.shape).to(keys_text.device)

        if self.random_text_flag:
            print("Random text flag")
            queries_text = self._randomize(queries_text, "Queries text")
            keys_text = self._randomize(keys_text, "Keys text")

        if self.random_image_flag:
            print("Random image flag")
            queries_image = self._randomize(queries_image, "Queries image")
            keys_image = self._randomize(keys_image, "Keys image")

        if self.random_everything_flag:
            print("Random everything flag")
            queries_text = self._randomize(queries_text, "Queries text")
            keys_text = self._randomize(keys_text, "Keys text")
            queries_image = self._randomize(queries_image, "Queries image")
            keys_image = self._randomize(keys_image, "Keys image")

        return queries_text, queries_image, keys_text, keys_image

    @staticmethod
    def _align_outputs(softmax, logits, real_labels):
        """Squeeze outputs and labels, then restore a leading batch axis if it was squeezed away.

        Labels are cast to float because the loss is computed with
        class-probability targets of the same shape as the logits.
        """
        softmax = softmax.squeeze()
        logits = logits.squeeze()
        real_labels = real_labels.squeeze().float()
        if real_labels.dim() < 3:
            real_labels = real_labels.unsqueeze(0)
            softmax = softmax.unsqueeze(0)
            logits = logits.unsqueeze(0)
        return softmax, logits, real_labels

    def _masked_loss(self, logits, real_labels):
        """Cross-entropy over every sample, masking padded key columns.

        A sample is padded when row 0 of its labels contains ``-100``; from the
        first such column onward the labels are zeroed and the logits set to
        ``-100``.  Unpadded samples are used unchanged.  (Presumably
        ``real_index`` is a one-hot match matrix per sample -- TODO confirm.)
        """
        masked_logits = []
        masked_labels = []
        for sample_logits, sample_labels in zip(logits, real_labels):
            padding = torch.nonzero(sample_labels[0] == -100)
            if padding.nelement() > 0:
                padding_index = int(padding[0])
                sample_labels = sample_labels.clone()
                sample_labels[:, padding_index:] = 0
                sample_logits = sample_logits.clone()
                sample_logits[:, padding_index:] = -100
            masked_logits.append(sample_logits)
            masked_labels.append(sample_labels)

        stacked_logits = torch.stack(masked_logits)
        stacked_labels = torch.stack(masked_labels).float()
        return self.loss_method(stacked_logits.mT, stacked_labels.mT)

    @staticmethod
    def _batch_metrics(softmax, real_labels):
        """Mean precision / accuracy / recall over the samples of a batch.

        Padded key columns (see ``_masked_loss``) are dropped.  For each
        remaining column the arg-max over axis 0 of the prediction is compared
        with that of the labels; each mismatch counts as one FP and one FN.
        """
        precisions = []
        accuracies = []
        recalls = []
        for sample_probs, sample_labels in zip(softmax, real_labels):
            padding = torch.nonzero(sample_labels[0] == -100)
            if padding.nelement() > 0:
                padding_index = int(padding[0])
                sample_labels = sample_labels[:, :padding_index]
                sample_probs = sample_probs[:, :padding_index]

            predicted = sample_probs.argmax(dim=0)
            target = sample_labels.argmax(dim=0)
            mismatches = torch.count_nonzero(predicted - target)
            samples = sample_probs.shape[1] * sample_probs.shape[0]

            tp = len(target) - mismatches
            fp = mismatches
            fn = mismatches
            tn = samples - tp - fp - fn

            precisions.append((tp / (tp + fp)).item())
            accuracies.append(((tp + tn) / samples).item())
            recalls.append((tp / (tp + fn)).item())

        return (sum(precisions) / len(precisions),
                sum(accuracies) / len(accuracies),
                sum(recalls) / len(recalls))

    # -------------------------------------------------------- Lightning hooks

    def training_step(self, train_batch, batch_idx):
        """Run one training step and log loss / precision / accuracy / recall.

        ``train_batch`` must provide ``'queries'`` and ``'keys'`` (nested
        ``(text, image)`` pairs) and ``'real_index'`` (targets, with ``-100``
        marking padding).
        """
        self._update_freeze_state()

        queries_text, queries_image = self._stack_pairs(train_batch['queries'])
        keys_text, keys_image = self._stack_pairs(train_batch['keys'])
        queries_text, queries_image, keys_text, keys_image = self._apply_input_ablations(
            queries_text, queries_image, keys_text, keys_image)

        softmax, logits = self.forward(queries_text, queries_image, keys_text, keys_image)
        softmax, logits, real_labels = self._align_outputs(softmax, logits, train_batch['real_index'])

        loss = self._masked_loss(logits, real_labels)
        precision, accuracy, recall = self._batch_metrics(softmax, real_labels)

        self.train_sum_precision += precision
        self.train_sum_accuracy += accuracy
        self.train_sum_recall += recall
        self.train_sum_runs += 1
        self.log("train_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("train_precision", precision, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("train_recall", recall, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Log the mean per-step training metrics and reset the accumulators."""
        if self.train_sum_runs:
            self.log("train_precision_epoch", self.train_sum_precision / self.train_sum_runs)
            self.log("train_accuracy_epoch", self.train_sum_accuracy / self.train_sum_runs)
            self.log("train_recall_epoch", self.train_sum_recall / self.train_sum_runs)
        self.train_sum_precision = 0
        self.train_sum_accuracy = 0
        self.train_sum_recall = 0
        self.train_sum_runs = 0

    def configure_optimizers(self):
        """Create a single Adam optimizer over all parameters (no LR scheduler)."""
        self.my_optimizer = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        return [self.my_optimizer]

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one batch and log loss / precision / accuracy / recall.

        Uses the same padding-aware loss as training, so every sample
        contributes whether or not it is padded.

        Raises:
            RuntimeError: if the loss is negative, which can only happen when
                the targets contain negative values that the padding mask did
                not catch.
        """
        queries_text, queries_image = self._stack_pairs(val_batch['queries'])
        keys_text, keys_image = self._stack_pairs(val_batch['keys'])

        softmax, logits = self.forward(queries_text, queries_image, keys_text, keys_image)
        softmax, logits, real_labels = self._align_outputs(softmax, logits, val_batch['real_index'])

        loss = self._masked_loss(logits, real_labels)
        if loss < 0:
            print("Loss: ", loss)
            print("Logits: ", logits)
            print("Real labels: ", real_labels)
            raise RuntimeError(
                f"Negative validation loss ({loss.item()}); 'real_index' contains "
                "negative targets outside the detected padding."
            )

        precision, accuracy, recall = self._batch_metrics(softmax, real_labels)

        self.val_sum_precision += precision
        self.val_sum_accuracy += accuracy
        self.val_sum_recall += recall
        self.val_sum_runs += 1
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("val_precision", precision, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        self.log("val_recall", recall, on_epoch=True, on_step=False, prog_bar=True, logger=True)

    def on_validation_epoch_end(self) -> None:
        """Log the mean per-step validation metrics and reset the accumulators."""
        if self.val_sum_runs:
            self.log("val_precision_epoch", self.val_sum_precision / self.val_sum_runs)
            self.log("val_accuracy_epoch", self.val_sum_accuracy / self.val_sum_runs)
            self.log("val_recall_epoch", self.val_sum_recall / self.val_sum_runs)
        self.val_sum_precision = 0
        self.val_sum_accuracy = 0
        self.val_sum_recall = 0
        self.val_sum_runs = 0
|
|
|
|
if __name__ == '__main__':
    # Input-ablation switches.  They are not exposed on the command line, so
    # they are defined up front for BOTH branches below (previously they only
    # existed in the default-values branch, so the CLI path raised NameError).
    random_text = False
    random_everything = False
    random_images = False
    fixed_text = False

    if len(sys.argv) > 1:
        print("Using arguments")
        if len(sys.argv) < 7:
            sys.exit("Usage: python <script> BATCH_SIZE LEARNING_RATE EPOCHS "
                     "WANDB(True|False) FIND_LR(True|False) UNFREEZE_EPOCH")
        batch_size = int(sys.argv[1])
        learning_rate = float(sys.argv[2])
        epochs = int(sys.argv[3])
        wandb_flag = sys.argv[4] == "True"
        find_lr = sys.argv[5] == "True"
        unfreeze = int(sys.argv[6])
    else:
        print("Using default values")
        batch_size = 500
        learning_rate = 0.01
        epochs = 50
        wandb_flag = True
        find_lr = False
        unfreeze = 10

    print("Batch size: ", batch_size)
    print("Learning rate: ", learning_rate)
    print("Epochs: ", epochs)
    print("Wandb flag: ", wandb_flag)
    print("Find lr: ", find_lr)
    print("Unfreeze: ", unfreeze)

    train_path = "./recipe_dataset_3500_real_1.pkl"
    val_path = "./recipe_dataset_3500_real_2.pkl"
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(train_path, "rb") as train_file:
        train = pickle.load(train_file)
    with open(val_path, "rb") as val_file:
        val = pickle.load(val_file)

    # Build the run-name prefix that identifies the dataset / ablation variant.
    if "wrong" in train_path and "wrong" in val_path:
        print("Using dataset with false positives")
        string_wrong = "WRONG_"
    elif "wrong" in train_path or "wrong" in val_path:
        raise ValueError("One of the datasets is wrong")
    else:
        print("Using normal dataset")
        string_wrong = ""
    if random_text:
        string_wrong += "RANDOM_TEXT_"
    elif random_everything:
        string_wrong += "RANDOM_EVERYTHING_"
    elif random_images:
        string_wrong += "RANDOM_IMAGES_"
    elif fixed_text:
        string_wrong += "FIXED_TEXT_"

    # The model only reads 'queries', 'keys' and 'real_index'; drop the id
    # fields before batching.
    for dataset in (train, val):
        for batch in dataset:
            batch.pop('ids_queries')
            batch.pop('ids_keys')

    # NOTE: len(DataLoader) is the number of batches, not of samples.
    train_dataset = DataLoader(train, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Train dataset size:", len(train_dataset))
    val_dataset = DataLoader(val, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Val dataset size:", len(val_dataset))

    model = SoftAttention(learning_rate=learning_rate, batch_size=batch_size, unfreeze=unfreeze,
                          random_text=random_text, random_everything=random_everything,
                          fixed_text=fixed_text, random_images=random_images)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    if wandb_flag:
        # Take the logger from the same ``lightning`` package as the Trainer
        # and callbacks, instead of the standalone ``pytorch_lightning`` one.
        from lightning.pytorch.loggers import WandbLogger

        run_name = (f"{string_wrong}MORE_RECIPES_{len(train_dataset)}_batch_{batch_size}"
                    f"_lr_{learning_rate}_epochs_{epochs}_unfreeze_{unfreeze}")
        wandb_logger = WandbLogger(project='reference_training', name=run_name, log_model="all")
        wandb_logger.experiment.config["batch_size"] = batch_size
        wandb_logger.experiment.config["max_epochs"] = epochs
        wandb_logger.experiment.config["learning_rate"] = learning_rate
        trainer = L.Trainer(max_epochs=epochs, detect_anomaly=False, logger=wandb_logger,
                            callbacks=[lr_monitor])
    else:
        trainer = L.Trainer(max_epochs=epochs, default_root_dir="./", callbacks=[lr_monitor])

    if find_lr:
        tuner = Tuner(trainer)
        lr_finder = tuner.lr_find(model, train_dataloaders=train_dataset, val_dataloaders=val_dataset)
        print(lr_finder.suggestion())
    else:
        trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=val_dataset)

    if wandb_flag:
        wandb.finish()
|
|