# -*- coding: UTF-8 -*-
import faiss
import torch
from torch import nn
import logging
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import warnings
import cv2
from os.path import join

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def match_batch_tensor(fm1, fm2, trainflag, grid_size, T2=0.7):
    '''Match one descriptor set against N others via mutual nearest neighbours.

    fm1: (l1, D) descriptors of the query, e.g. 529 x 768.
    fm2: (N, l2, D) descriptors of N candidates, e.g. 100 x 529 x 768.
    trainflag: if True, return a differentiable mean similarity for the
        first candidate only (callers pass N == 1 in that mode); otherwise
        return one match count per candidate.
    grid_size: unused; kept for interface compatibility.
    T2: a mutual-NN pair is kept only if its dot product exceeds T2.
        NOTE(review): this is a cosine similarity only if descriptors are
        L2-normalized upstream - confirm.

    Returns:
        trainflag=True: scalar tensor, the mean dot product of the kept pairs,
            or the mean position-wise dot product of fm1 and fm2[0] when no
            pair survives (that fallback requires l1 == l2).
        trainflag=False: (N,) float tensor with the number of kept pairs.
    '''
    # sim[i, k, j] = <fm2[i, k], fm1[j]>  -> shape (N, l2, l1)
    sim = torch.matmul(fm2, fm1.T)
    # For each fm1 descriptor j, its best match inside each fm2[i]  -> (N, l1)
    nn_in_2 = torch.argmax(sim, dim=1)
    # For each fm2 descriptor k, its best match inside fm1          -> (N, l2)
    nn_in_1 = torch.argmax(sim, dim=2)
    # Mutual check: j -> nn_in_2[i, j] -> nn_in_1[i, .] must lead back to j.
    round_trip = torch.gather(nn_in_1, 1, nn_in_2)
    own_index = torch.arange(sim.shape[-1], device=sim.device)
    mutual = round_trip == own_index
    # Dot product of each (j, nn_in_2[i, j]) pair, read directly from sim.
    pair_sim = torch.gather(sim, 1, nn_in_2.unsqueeze(1)).squeeze(1)
    # Boolean masks stay 1-D per row, so a single match counts as one match
    # and an empty selection is detected explicitly. The old
    # nonzero().squeeze() collapsed a single match to a 0-d tensor.
    keep = mutual & (pair_sim > T2)

    if trainflag:
        if keep[0].any():
            return pair_sim[0][keep[0]].mean()
        print("No mutual nearest neighbors!")
        return torch.mean(torch.sum(fm1 * fm2[0], dim=1), dim=0)

    return keep.sum(dim=1).float()


def local_sim(features_1, features_2, trainflag=False):
    '''Score local-feature similarity between queries and candidates.

    trainflag=True: features_1[i] is matched against features_2[i] for each
        of the B rows, and a (B,) tensor of mean similarities is returned.
    trainflag=False: features_1 is a single (l, D) query and features_2 holds
        its B candidates; a (B,) tensor of mutual-match counts is returned.
    features_2 must be 3-D (B, l, D); the unpack below enforces the rank.
    '''
    B, _, _ = features_2.shape
    if not trainflag:
        return match_batch_tensor(features_1, features_2, trainflag, grid_size=(61, 61))

    # Allocate on the inputs' device instead of hard-coding CUDA.
    similarity = torch.zeros(B, device=features_2.device)
    for i in range(B):
        similarity[i] = match_batch_tensor(
            features_1[i], features_2[i].unsqueeze(0), trainflag, grid_size=(61, 61)
        )
    return similarity


def rerank(predictions, queries_local_features, database_local_features):
    '''Reorder each query's retrieved candidates by local-feature match count.

    predictions: (Q, K) array of database indices from the global search.
    queries_local_features[q] / database_local_features[idx]: local
        descriptors, converted with torch.as_tensor and moved to `device`.
    Returns a numpy array with each row of `predictions` reordered by
    descending score. Ties keep their original (global-retrieval) order.
    '''
    reranked = []
    print("reranking...")
    for query_index, pred in enumerate(tqdm(predictions)):
        query_local = torch.as_tensor(queries_local_features[query_index], device=device)
        candidates_local = torch.as_tensor(database_local_features[pred], device=device)
        scores = local_sim(query_local, candidates_local, trainflag=False)
        # Stable descending sort. argsort()[::-1] reversed the global order
        # among equal counts, which fully inverts the ranking when all are 0.
        order = np.argsort(-scores.cpu().numpy(), kind="stable")
        reranked.append(pred[order])
    return np.array(reranked)


def run_rerank(queries_features, database_features, q_local_list, r_local_list, recall_values=(1,)):
    '''Retrieve top-max(recall_values) candidates with exact L2 search, then rerank.

    The index dimension is taken from database_features instead of being
    hard-coded (it was 8448). recall_values defaults to an immutable tuple;
    any iterable of ints still works.
    '''
    faiss_index = faiss.IndexFlatL2(database_features.shape[1])
    faiss_index.add(database_features)
    _, predictions = faiss_index.search(queries_features, max(recall_values))
    return rerank(predictions, q_local_list, r_local_list)