import copy
import logging

import torch
import torch.nn as nn
from prettytable import PrettyTable
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                          RobertaConfig, RobertaModel, RobertaTokenizer)

logger = logging.getLogger(__name__)


def whitening_torch_final(embeddings):
    """Whiten a matrix of embeddings: center them, then map through
    W = U * diag(s^{-1/2}) from the SVD of their (unnormalized) covariance,
    so the result has identity covariance up to scale."""
    mu = torch.mean(embeddings, dim=0, keepdim=True)
    cov = torch.mm((embeddings - mu).t(), embeddings - mu)
    u, s, v = torch.svd(cov)  # torch.svd returns V, not V^T
    W = torch.mm(u, torch.diag(1 / torch.sqrt(s)))
    embeddings = torch.mm(embeddings - mu, W)
    return embeddings
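
# Usage sketch (illustrative shapes only): whitening is applied to a full
# matrix of pooled embeddings at once, e.g. before building a retrieval index
# when `do_whitening` is enabled:
#
#     embs = torch.randn(1000, 768)        # [num_examples, hidden_size]
#     white = whitening_torch_final(embs)  # same shape, ~identity covariance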


class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()

    def model_parameters(self):
        """Render all trainable parameters as a PrettyTable."""
        table = PrettyTable()
        table.field_names = ["Layer Name", "Output Shape", "Param #"]
        table.align["Layer Name"] = "l"
        table.align["Output Shape"] = "r"
        table.align["Param #"] = "r"
        for name, parameters in self.named_parameters():
            if parameters.requires_grad:
                table.add_row([name, str(list(parameters.shape)), parameters.numel()])
        return table
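
    # Usage sketch: `print(model.model_parameters())` prints an ASCII table of
    # the trainable parameters with their shapes and element counts.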


class Model(BaseModel):
    """Bi-encoder wrapper used at inference: mean-pools the encoder output over
    non-pad tokens and L2-normalizes it."""

    def __init__(self, encoder):
        super(Model, self).__init__()
        self.encoder = encoder

    def forward(self, code_inputs=None, nl_inputs=None):
        # Token id 1 is the RoBERTa pad token, so `.ne(1)` doubles as the attention mask.
        if code_inputs is not None:
            outputs = self.encoder(code_inputs, attention_mask=code_inputs.ne(1))[0]
            outputs = (outputs * code_inputs.ne(1)[:, :, None]).sum(1) / code_inputs.ne(1).sum(-1)[:, None]
            return torch.nn.functional.normalize(outputs, p=2, dim=1)
        else:
            outputs = self.encoder(nl_inputs, attention_mask=nl_inputs.ne(1))[0]
            outputs = (outputs * nl_inputs.ne(1)[:, :, None]).sum(1) / nl_inputs.ne(1).sum(-1)[:, None]
            return torch.nn.functional.normalize(outputs, p=2, dim=1)
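
    # Usage sketch (assumes RoBERTa-style inputs padded with token id 1):
    #     code_vec = model(code_inputs=code_ids)  # [batch, hidden], L2-normalized
    #     nl_vec   = model(nl_inputs=nl_ids)      # [batch, hidden], L2-normalized
    #     scores   = nl_vec @ code_vec.T          # cosine-similarity matrix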


class Multi_Loss_CoCoSoDa(BaseModel):
    """MoCo-style multi-loss model: momentum key encoders plus negative-sample
    queues for code, masked code, NL, and masked NL."""

    def __init__(self, base_encoder, args, mlp=False):
        super(Multi_Loss_CoCoSoDa, self).__init__()

        self.K = args.moco_k   # queue size (number of cached negatives)
        self.m = args.moco_m   # momentum coefficient for the key encoders
        self.T = args.moco_t   # softmax temperature
        dim = args.moco_dim    # embedding dimension

        # The code and NL query encoders share weights (the same module);
        # each modality gets its own momentum-updated key encoder.
        self.code_encoder_q = base_encoder
        self.code_encoder_k = copy.deepcopy(base_encoder)
        self.nl_encoder_q = base_encoder
        self.nl_encoder_k = copy.deepcopy(self.nl_encoder_q)

        self.mlp = mlp
        self.time_score = args.time_score
        self.do_whitening = args.do_whitening
        self.do_ineer_loss = args.do_ineer_loss
        self.agg_way = args.agg_way
        self.args = args

        # Initialize the key encoders from the query encoders and freeze them;
        # they are only ever updated by momentum, never by gradients.
        for param_q, param_k in zip(self.code_encoder_q.parameters(), self.code_encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        for param_q, param_k in zip(self.nl_encoder_q.parameters(), self.nl_encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Four negative queues, one per view: each is a (dim, K) buffer of
        # L2-normalized random vectors with a ring-buffer write pointer.
        torch.manual_seed(3047)
        torch.cuda.manual_seed(3047)
        self.register_buffer("code_queue", torch.randn(dim, self.K))
        self.code_queue = nn.functional.normalize(self.code_queue, dim=0)
        self.register_buffer("code_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("masked_code_queue", torch.randn(dim, self.K))
        self.masked_code_queue = nn.functional.normalize(self.masked_code_queue, dim=0)
        self.register_buffer("masked_code_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("nl_queue", torch.randn(dim, self.K))
        self.nl_queue = nn.functional.normalize(self.nl_queue, dim=0)
        self.register_buffer("nl_queue_ptr", torch.zeros(1, dtype=torch.long))

        self.register_buffer("masked_nl_queue", torch.randn(dim, self.K))
        self.masked_nl_queue = nn.functional.normalize(self.masked_nl_queue, dim=0)
        self.register_buffer("masked_nl_queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum (EMA) update of the key encoders from the query encoders."""
        for param_q, param_k in zip(self.code_encoder_q.parameters(), self.code_encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.nl_encoder_q.parameters(), self.nl_encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        if self.mlp:
            # Note: the *_fc projection heads are not created in this file and
            # must be attached externally before mlp=True is usable.
            for param_q, param_k in zip(self.code_encoder_q_fc.parameters(), self.code_encoder_k_fc.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
            for param_q, param_k in zip(self.nl_encoder_q_fc.parameters(), self.nl_encoder_k_fc.parameters()):
                param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
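
    # Worked example (illustrative): with the usual MoCo setting m = 0.999,
    # each key parameter moves only 0.1% of the way toward its query
    # counterpart per step, i.e. the key encoder is a slow moving average:
    #     param_k <- 0.999 * param_k + 0.001 * param_q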

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys, option='code'):
        """Replace the oldest `batch_size` entries of the selected queue with `keys`."""
        batch_size = keys.shape[0]
        if option == 'code':
            code_ptr = int(self.code_queue_ptr)
            assert self.K % batch_size == 0  # batches must tile the queue exactly
            try:
                self.code_queue[:, code_ptr:code_ptr + batch_size] = keys.T
            except RuntimeError:
                logger.error("code enqueue failed: ptr=%d batch_size=%d keys=%s",
                             code_ptr, batch_size, tuple(keys.shape))
                raise
            code_ptr = (code_ptr + batch_size) % self.K  # advance and wrap the pointer
            self.code_queue_ptr[0] = code_ptr
        elif option == 'masked_code':
            masked_code_ptr = int(self.masked_code_queue_ptr)
            assert self.K % batch_size == 0
            try:
                self.masked_code_queue[:, masked_code_ptr:masked_code_ptr + batch_size] = keys.T
            except RuntimeError:
                logger.error("masked_code enqueue failed: ptr=%d batch_size=%d keys=%s",
                             masked_code_ptr, batch_size, tuple(keys.shape))
                raise
            masked_code_ptr = (masked_code_ptr + batch_size) % self.K
            self.masked_code_queue_ptr[0] = masked_code_ptr
        elif option == 'nl':
            nl_ptr = int(self.nl_queue_ptr)
            assert self.K % batch_size == 0
            self.nl_queue[:, nl_ptr:nl_ptr + batch_size] = keys.T
            nl_ptr = (nl_ptr + batch_size) % self.K
            self.nl_queue_ptr[0] = nl_ptr
        elif option == 'masked_nl':
            masked_nl_ptr = int(self.masked_nl_queue_ptr)
            assert self.K % batch_size == 0
            self.masked_nl_queue[:, masked_nl_ptr:masked_nl_ptr + batch_size] = keys.T
            masked_nl_ptr = (masked_nl_ptr + batch_size) % self.K
            self.masked_nl_queue_ptr[0] = masked_nl_ptr
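
    # Ring-buffer example (illustrative): with K = 8 and batch_size = 4 the
    # pointer cycles 0 -> 4 -> 0 -> ..., so each enqueue overwrites the oldest
    # half of the queue; the assert guarantees writes never wrap past K.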

    def forward(self, source_code_q, source_code_k, nl_q, nl_k):
        """
        Input:
            source_code_q / source_code_k: a batch of original / augmented (masked) code
            nl_q / nl_k: a batch of original / augmented (masked) natural language
        Output:
            logits, targets, and the pooled code / NL query embeddings
        """
        if not self.args.do_multi_lang_continue_pre_train:
            # Plain in-batch contrastive objective: mean-pool, normalize, and
            # score every code against every NL in the batch; the diagonal
            # entries (label i for row i) are the positive pairs.
            outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[0]
            code_q = (outputs * source_code_q.ne(1)[:, :, None]).sum(1) / source_code_q.ne(1).sum(-1)[:, None]
            code_q = torch.nn.functional.normalize(code_q, p=2, dim=1)

            outputs = self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[0]
            nl_q = (outputs * nl_q.ne(1)[:, :, None]).sum(1) / nl_q.ne(1).sum(-1)[:, None]
            nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1)

            code2nl_logits = torch.einsum("ab,cb->ac", code_q, nl_q)
            code2nl_logits /= self.T
            code2nl_label = torch.arange(code2nl_logits.size(0), device=code2nl_logits.device)
            return code2nl_logits, code2nl_label, None, None

        if self.agg_way == "avg":
            # Queries: mean-pool the query encoders over non-pad tokens.
            outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[0]
            code_q = (outputs * source_code_q.ne(1)[:, :, None]).sum(1) / source_code_q.ne(1).sum(-1)[:, None]
            code_q = torch.nn.functional.normalize(code_q, p=2, dim=1)

            outputs = self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[0]
            nl_q = (outputs * nl_q.ne(1)[:, :, None]).sum(1) / nl_q.ne(1).sum(-1)[:, None]
            nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1)

            with torch.no_grad():
                self._momentum_update_key_encoder()

                # Keys: the augmented views through the momentum encoders, gradient-free.
                outputs = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1))[0]
                code_k = (outputs * source_code_k.ne(1)[:, :, None]).sum(1) / source_code_k.ne(1).sum(-1)[:, None]
                code_k = torch.nn.functional.normalize(code_k, p=2, dim=1)

                outputs = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1))[0]
                nl_k = (outputs * nl_k.ne(1)[:, :, None]).sum(1) / nl_k.ne(1).sum(-1)[:, None]
                nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1)

        elif self.agg_way == "cls_pooler":
            # Use the encoder's pooler output directly (fixes the original code,
            # which computed the pooler output but never assigned it).
            code_q = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))[1]
            code_q = torch.nn.functional.normalize(code_q, p=2, dim=1)

            nl_q = self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))[1]
            nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1)

            with torch.no_grad():
                self._momentum_update_key_encoder()

                code_k = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1))[1]
                code_k = torch.nn.functional.normalize(code_k, p=2, dim=1)

                nl_k = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1))[1]
                nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1)

        elif self.agg_way == "avg_cls_pooler":
            # Sum of the pooler output and the mean-pooled hidden states.
            outputs = self.code_encoder_q(source_code_q, attention_mask=source_code_q.ne(1))
            code_q_cls = outputs[1]
            outputs = outputs[0]
            code_q_avg = (outputs * source_code_q.ne(1)[:, :, None]).sum(1) / source_code_q.ne(1).sum(-1)[:, None]
            code_q = code_q_cls + code_q_avg
            code_q = torch.nn.functional.normalize(code_q, p=2, dim=1)

            outputs = self.nl_encoder_q(nl_q, attention_mask=nl_q.ne(1))
            nl_q_cls = outputs[1]
            outputs = outputs[0]
            nl_q_avg = (outputs * nl_q.ne(1)[:, :, None]).sum(1) / nl_q.ne(1).sum(-1)[:, None]
            nl_q = nl_q_avg + nl_q_cls
            nl_q = torch.nn.functional.normalize(nl_q, p=2, dim=1)

            with torch.no_grad():
                self._momentum_update_key_encoder()

                outputs = self.code_encoder_k(source_code_k, attention_mask=source_code_k.ne(1))
                code_k_cls = outputs[1]
                outputs = outputs[0]
                code_k_avg = (outputs * source_code_k.ne(1)[:, :, None]).sum(1) / source_code_k.ne(1).sum(-1)[:, None]
                code_k = code_k_cls + code_k_avg
                code_k = torch.nn.functional.normalize(code_k, p=2, dim=1)

                outputs = self.nl_encoder_k(nl_k, attention_mask=nl_k.ne(1))
                nl_k_cls = outputs[1]
                outputs = outputs[0]
                nl_k_avg = (outputs * nl_k.ne(1)[:, :, None]).sum(1) / nl_k.ne(1).sum(-1)[:, None]
                nl_k = nl_k_cls + nl_k_avg
                nl_k = torch.nn.functional.normalize(nl_k, p=2, dim=1)

        # code -> NL: positives are the in-batch NL queries (rescaled by
        # time_score), negatives come from the NL queue.
        code2nl_pos = torch.einsum('nc,bc->nb', [code_q, nl_q])
        code2nl_neg = torch.einsum('nc,ck->nk', [code_q, self.nl_queue.clone().detach()])
        code2nl_logits = torch.cat([self.time_score * code2nl_pos, code2nl_neg], dim=1)
        code2nl_logits /= self.T
        code2nl_label = torch.arange(code2nl_logits.size(0), device=code2nl_logits.device)

        # code -> masked NL: positives are the momentum-encoded masked NL keys.
        code2maskednl_pos = torch.einsum('nc,bc->nb', [code_q, nl_k])
        code2maskednl_neg = torch.einsum('nc,ck->nk', [code_q, self.masked_nl_queue.clone().detach()])
        code2maskednl_logits = torch.cat([self.time_score * code2maskednl_pos, code2maskednl_neg], dim=1)
        code2maskednl_logits /= self.T
        code2maskednl_label = torch.arange(code2maskednl_logits.size(0), device=code2maskednl_logits.device)

        # NL -> code
        nl2code_pos = torch.einsum('nc,bc->nb', [nl_q, code_q])
        nl2code_neg = torch.einsum('nc,ck->nk', [nl_q, self.code_queue.clone().detach()])
        nl2code_logits = torch.cat([self.time_score * nl2code_pos, nl2code_neg], dim=1)
        nl2code_logits /= self.T
        nl2code_label = torch.arange(nl2code_logits.size(0), device=nl2code_logits.device)

        # NL -> masked code
        nl2maskedcode_pos = torch.einsum('nc,bc->nb', [nl_q, code_k])
        nl2maskedcode_neg = torch.einsum('nc,ck->nk', [nl_q, self.masked_code_queue.clone().detach()])
        nl2maskedcode_logits = torch.cat([self.time_score * nl2maskedcode_pos, nl2maskedcode_neg], dim=1)
        nl2maskedcode_logits /= self.T
        nl2maskedcode_label = torch.arange(nl2maskedcode_logits.size(0), device=nl2maskedcode_logits.device)

        # Stack the four inter-modal objectives row-wise.
        inter_logits = torch.cat((code2nl_logits, code2maskednl_logits, nl2code_logits, nl2maskedcode_logits), dim=0)
        inter_labels = torch.cat((code2nl_label, code2maskednl_label, nl2code_label, nl2maskedcode_label), dim=0)
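
        # Shape check (illustrative): with batch size B and queue size K, each
        # logits block is [B, B + K]; after stacking, inter_logits is
        # [4 * B, B + K] (or [6 * B, B + K] once the inner losses below are
        # appended) and inter_labels holds 0..B-1 repeated per block, so one
        # cross-entropy averages all contrastive objectives at once.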

        if self.do_ineer_loss:
            # code -> masked code (intra-modal)
            code2maskedcode_pos = torch.einsum('nc,bc->nb', [code_q, code_k])
            code2maskedcode_neg = torch.einsum('nc,ck->nk', [code_q, self.masked_code_queue.clone().detach()])
            code2maskedcode_logits = torch.cat([self.time_score * code2maskedcode_pos, code2maskedcode_neg], dim=1)
            code2maskedcode_logits /= self.T
            code2maskedcode_label = torch.arange(code2maskedcode_logits.size(0), device=code2maskedcode_logits.device)

            # NL -> masked NL (intra-modal)
            nl2maskednl_pos = torch.einsum('nc,bc->nb', [nl_q, nl_k])
            nl2maskednl_neg = torch.einsum('nc,ck->nk', [nl_q, self.masked_nl_queue.clone().detach()])
            nl2maskednl_logits = torch.cat([self.time_score * nl2maskednl_pos, nl2maskednl_neg], dim=1)
            nl2maskednl_logits /= self.T
            nl2maskednl_label = torch.arange(nl2maskednl_logits.size(0), device=nl2maskednl_logits.device)

            inter_logits = torch.cat((inter_logits, code2maskedcode_logits, nl2maskednl_logits), dim=0)
            inter_labels = torch.cat((inter_labels, code2maskedcode_label, nl2maskednl_label), dim=0)

        # Refresh the queues with the current batch of embeddings.
        self._dequeue_and_enqueue(code_q, option='code')
        self._dequeue_and_enqueue(nl_q, option='nl')
        self._dequeue_and_enqueue(code_k, option='masked_code')
        self._dequeue_and_enqueue(nl_k, option='masked_nl')

        return inter_logits, inter_labels, code_q, nl_q
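

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the original
    # training pipeline. ToyEncoder and every hyperparameter below are
    # hypothetical; the toy encoder only mimics the (sequence_output,
    # pooler_output) tuple returned by RobertaModel.
    from types import SimpleNamespace

    class ToyEncoder(nn.Module):
        def __init__(self, vocab_size=50, dim=16):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)

        def forward(self, input_ids, attention_mask=None):
            hidden = self.emb(input_ids)  # [batch, seq, dim]
            return hidden, hidden[:, 0]   # (sequence_output, "pooler" output)

    args = SimpleNamespace(moco_k=8, moco_m=0.999, moco_t=0.07, moco_dim=16,
                           time_score=1, do_whitening=False, do_ineer_loss=True,
                           agg_way="avg", do_multi_lang_continue_pre_train=True)
    model = Multi_Loss_CoCoSoDa(ToyEncoder(), args)
    ids = torch.randint(2, 50, (4, 10))  # batch of 4; K=8 is divisible by 4
    logits, labels, code_q, nl_q = model(ids, ids, ids, ids)
    print(logits.shape, labels.shape)    # torch.Size([24, 12]) torch.Size([24])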