Spaces:
Sleeping
Sleeping
# -*- coding: UTF-8 -*-
import faiss
import torch
from torch import nn
import logging
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
# from prettytable import PrettyTable
import warnings
import cv2
from os.path import join
# Device string: 'cuda' when torch reports an available CUDA device, otherwise 'cpu'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# import matplotlib.pyplot as plt
def match_batch_tensor(fm1, fm2, trainflag, grid_size, T2=0.7):
    """Match one set of local features against N candidate sets via mutual nearest neighbours.

    A pair (query feature b, candidate feature a) is kept when each is the
    other's arg-max in the dot-product matrix and their dot product exceeds
    ``T2``.

    Args:
        fm1: query local features, shape (l, D) (e.g. 529 x 768 per the
            original notes).
        fm2: candidate local features, shape (N, l, D).
        trainflag: if True, return a similarity for the FIRST candidate only
            (the loop returns on its first iteration); if False, return the
            number of kept pairs for every candidate.
        grid_size: accepted for interface compatibility; not used here.
        T2: minimum dot product for a mutual pair to be kept. This is a cosine
            threshold only if the features are L2-normalised -- assumed by the
            original comment, TODO confirm with the caller.

    Returns:
        trainflag=True: 0-dim tensor, the mean dot product over the kept pairs;
            when no pair is kept, the mean dot product of position-aligned
            features (which requires fm1 and fm2[0] to have the same length).
        trainflag=False: float tensor of shape (N,) with the kept-pair counts.
        All outputs live on the device of ``fm2``.
    """
    # Stay on whatever device the inputs use instead of forcing CUDA, so the
    # function also runs on CPU-only hosts.
    dev = fm2.device

    # M[n, a, b] = <fm2[n, a], fm1[b]>
    M = torch.matmul(fm2, fm1.T)
    max1 = torch.argmax(M, dim=1)  # (N, l): best candidate feature for each query feature
    max2 = torch.argmax(M, dim=2)  # (N, l): best query feature for each candidate feature

    # m[n, b] = query feature preferred by the candidate feature that query
    # feature b prefers; the match is mutual exactly when that is b itself.
    batch_idx = torch.arange(M.shape[0], device=dev).reshape((-1, 1))
    m = max2[batch_idx, max1]
    valid = torch.arange(M.shape[-1], device=dev) == m

    scores = torch.zeros(fm2.shape[0], device=dev)
    for i in range(fm2.shape[0]):
        # flatten() keeps the index vector 1-D even with a single hit; a bare
        # squeeze() would collapse it to 0-dim and the lone match was then
        # treated as "no match".
        idx1 = torch.nonzero(valid[i]).flatten()
        idx2 = max1[i][idx1]

        if idx1.numel() > 0:
            # Drop mutual pairs whose similarity does not exceed the threshold.
            pair_similarity = torch.sum(fm1[idx1] * fm2[i][idx2], dim=1)
            keep = pair_similarity > T2
            idx1 = idx1[keep]
            idx2 = idx2[keep]

        if trainflag:
            if idx1.numel() > 0:
                similarity = torch.mean(torch.sum(fm1[idx1] * fm2[i][idx2], dim=1), dim=0)
            else:
                # Nothing survived: fall back to position-aligned similarity
                # rather than returning the NaN of an empty mean.
                print("No mutual nearest neighbors!")
                similarity = torch.mean(torch.sum(fm1 * fm2[i], dim=1), dim=0)
            return similarity

        scores[i] = idx1.numel()
    return scores
def local_sim(features_1, features_2, trainflag=False):
    """Compute local-feature similarity via ``match_batch_tensor``.

    Args:
        features_1: with trainflag=True, a batch indexed along dim 0 whose
            i-th entry is paired with ``features_2[i]``; with trainflag=False,
            a single query passed through unchanged.
        features_2: 3-D tensor (unpacking its shape raises ValueError
            otherwise). With trainflag=True its i-th entry is the partner of
            ``features_1[i]``; with trainflag=False it is the stack of
            candidates for the single query.
        trainflag: selects pairwise similarity (True) or per-candidate
            match scores (False).

    Returns:
        trainflag=True: tensor of shape (B,), one similarity per pair, on the
            device of ``features_2``.
        trainflag=False: whatever ``match_batch_tensor`` returns for the
            query against all candidates.
    """
    batch_size, _, _ = features_2.shape

    if not trainflag:
        return match_batch_tensor(features_1, features_2, trainflag, grid_size=(61, 61))

    # Allocate on the inputs' device rather than forcing CUDA.
    similarity = torch.zeros(batch_size, device=features_2.device)
    for i in range(batch_size):
        # match_batch_tensor expects a leading candidate axis, hence unsqueeze.
        similarity[i] = match_batch_tensor(
            features_1[i], features_2[i].unsqueeze(0), trainflag, grid_size=(61, 61)
        )
    return similarity
def rerank(predictions, queries_local_features, database_local_features):
    """Re-order each query's retrieved candidates by local-feature match score.

    Args:
        predictions: per-query arrays of database indices; each row is used
            both to index ``database_local_features`` and to be re-ordered, so
            rows are presumably NumPy integer arrays -- verify against caller.
        queries_local_features: local features indexed by query position.
        database_local_features: local features indexed by database index;
            must support indexing with a whole row of ``predictions``.

    Returns:
        np.ndarray with one re-ordered prediction row per query, best local
        match first.
    """
    pred2 = []
    print("reranking...")
    for query_index, pred in enumerate(tqdm(predictions)):
        # Use the module-level device so this also runs without a GPU.
        query_local_features = torch.as_tensor(
            queries_local_features[query_index], device=device
        )
        candidates_local_features = torch.as_tensor(
            database_local_features[pred], device=device
        )
        scores = local_sim(
            query_local_features, candidates_local_features, trainflag=False
        ).cpu().numpy()
        # Stable descending sort: tied candidates keep their incoming order.
        # Reversing an ascending argsort would flip the order of ties.
        order = np.argsort(-scores, kind="stable")
        pred2.append(pred[order])
    return np.array(pred2)
def run_rerank(queries_features, database_features, q_local_list, r_local_list, recall_values=(1,)):
    """Retrieve candidates with an exact L2 faiss index, then re-rank them locally.

    Args:
        queries_features: 2-D array of global query descriptors.
        database_features: 2-D array of global database descriptors; its
            second dimension sets the index dimensionality.
        q_local_list: local features per query, forwarded to ``rerank``.
        r_local_list: local features per database item, forwarded to ``rerank``.
        recall_values: iterable of recall cut-offs; only its maximum is used,
            as the number of candidates retrieved per query.

    Returns:
        The re-ranked predictions produced by ``rerank``.
    """
    # faiss works on float32, C-contiguous arrays; this is a no-op when the
    # inputs already satisfy that.
    database_features = np.ascontiguousarray(database_features, dtype=np.float32)
    queries_features = np.ascontiguousarray(queries_features, dtype=np.float32)

    # Take the descriptor size from the data instead of hard-coding 8448.
    faiss_index = faiss.IndexFlatL2(database_features.shape[1])
    faiss_index.add(database_features)
    _, predictions = faiss_index.search(queries_features, max(recall_values))

    # rerank
    return rerank(predictions, q_local_list, r_local_list)