python-FRED / FoL /reranking.py
CMalone-Jupiter's picture
Upload folder using huggingface_hub
a37f5d3 verified
# -*- coding: UTF-8 -*-
import faiss
import torch
from torch import nn
import logging
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
# from prettytable import PrettyTable
import warnings
import cv2
from os.path import join
# Preferred compute device: CUDA when torch reports it as available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# import matplotlib.pyplot as plt
def match_batch_tensor(fm1, fm2, trainflag, grid_size, T2=0.7):
    '''
    Match a query's local features against N candidates using mutual nearest
    neighbours whose similarity exceeds a threshold.

    fm1: (l, D) local features of the query, e.g. (529, 768).
    fm2: (N, l, D) local features of the candidates, e.g. (100, 529, 768).
    trainflag: if True, return a similarity value for the FIRST candidate only
        (the function returns inside the first loop iteration); if False,
        return a match count for every candidate.
    grid_size: not used by this function; kept so existing callers still work.
    T2: a mutual pair is kept only if the dot product of its two features is
        strictly greater than this value (a cosine similarity when the
        features are L2-normalised -- presumably the case; confirm with caller).

    Returns:
        trainflag=True : 0-dim tensor, the mean similarity of the kept pairs.
            When no pair is kept, falls back to the mean of the position-wise
            dot products of fm1 and fm2[0] (requires matching shapes).
        trainflag=False: (N,) float tensor holding the number of kept pairs
            per candidate.

    All intermediate tensors are created on the device of the inputs, so the
    function works on CPU as well as on CUDA.
    '''
    dev = fm2.device
    # M[n, a, b] = <fm2[n, a], fm1[b]>
    M = torch.matmul(fm2, fm1.T)
    max1 = torch.argmax(M, dim=1)  # for each query feature b: best candidate feature a
    max2 = torch.argmax(M, dim=2)  # for each candidate feature a: best query feature b
    batch_idx = torch.arange(M.shape[0], device=dev).reshape((-1, 1))
    # m[n, b]: the query feature that b's best candidate feature points back to.
    m = max2[batch_idx, max1]
    # A query feature is a mutual nearest neighbour when it maps back to itself.
    valid = torch.arange(M.shape[-1], device=dev).repeat((M.shape[0], 1)) == m
    scores = torch.zeros(fm2.shape[0], device=dev)
    for i in range(fm2.shape[0]):
        # as_tuple keeps the index tensor 1-D even for zero or one match;
        # a plain squeeze() would collapse a single match to a 0-dim tensor.
        idx1 = torch.nonzero(valid[i], as_tuple=True)[0]
        idx2 = max1[i][idx1]
        pair_sims = torch.sum(fm1[idx1] * fm2[i][idx2], dim=1)
        pair_sims = pair_sims[pair_sims > T2]
        if trainflag:
            if pair_sims.numel() > 0:
                return torch.mean(pair_sims, dim=0)
            # Nothing survived the threshold: avoid the NaN that the mean of
            # an empty tensor would produce and use the dense fallback.
            print("No mutual nearest neighbors!")
            return torch.mean(torch.sum(fm1 * fm2[i], dim=1), dim=0)
        scores[i] = pair_sims.numel()
    return scores
def local_sim(features_1, features_2, trainflag=False):
    '''
    Compute local-feature similarity via match_batch_tensor.

    features_1: query local features. With trainflag=True it is indexed per
        sample (features_1[i]); with trainflag=False it is passed through as
        a single query.
    features_2: 3-D tensor of local features (unpacked as B, Num, C). With
        trainflag=True, features_2[i] is paired with features_1[i]; with
        trainflag=False all B entries are candidates for the one query.
    trainflag: selects the paired (training) or one-vs-many (evaluation) mode.

    Returns:
        trainflag=True : (B,) tensor with one similarity value per pair.
        trainflag=False: whatever match_batch_tensor returns for the batch
            (one score per candidate).

    The output buffer is allocated on the device of features_2 rather than
    unconditionally on CUDA, so the function also runs on CPU-only hosts.
    '''
    # Unpacking also enforces that features_2 is 3-D.
    B, _, _ = features_2.shape
    if not trainflag:
        return match_batch_tensor(features_1, features_2, trainflag, grid_size=(61, 61))

    similarity = torch.zeros(B, device=features_2.device)
    for i in range(B):
        # match_batch_tensor expects a batch of candidates, hence unsqueeze(0).
        similarity[i] = match_batch_tensor(
            features_1[i], features_2[i].unsqueeze(0), trainflag, grid_size=(61, 61)
        )
    return similarity
def rerank(predictions, queries_local_features, database_local_features):
    '''
    Re-order each query's retrieved candidates by local-feature match score.

    predictions: per-query arrays of database indices (one row per query), as
        produced by the global retrieval step.
    queries_local_features: indexable by query position; each entry holds the
        local features of one query.
    database_local_features: indexable by an array of database indices; yields
        the local features of the selected candidates.

    Returns:
        np.ndarray with the same rows as predictions, each row sorted by
        descending local_sim score. Candidates with equal scores keep their
        original (global retrieval) order, because scores are integer match
        counts and ties are common.

    Tensors are placed on the module-level `device` (CUDA when available,
    otherwise CPU) instead of unconditionally on CUDA.
    '''
    reranked = []
    print("reranking...")
    for query_index, pred in enumerate(tqdm(predictions)):
        query_local_features = torch.tensor(queries_local_features[query_index]).to(device)
        positives_local_features = torch.tensor(database_local_features[pred]).to(device)
        scores = local_sim(query_local_features, positives_local_features, trainflag=False)
        # Stable sort on the negated scores: descending order, ties left in
        # their incoming order rather than reversed or shuffled.
        order = np.argsort(-scores.detach().cpu().numpy(), kind="stable")
        reranked.append(pred[order])
    return np.array(reranked)
def run_rerank(queries_features, database_features, q_local_list, r_local_list, recall_values=(1,)):
    '''
    Retrieve candidates with an exact L2 faiss index, then re-rank them with
    local features.

    queries_features: 2-D array of global query descriptors passed to
        faiss search.
    database_features: 2-D array of global database descriptors; its second
        dimension sets the index dimensionality (previously fixed at 8448).
    q_local_list: local features of the queries, forwarded to rerank.
    r_local_list: local features of the database entries, forwarded to rerank.
    recall_values: iterable of recall cut-offs; max(recall_values) candidates
        are retrieved per query.

    Returns:
        The re-ranked prediction array produced by rerank.
    '''
    faiss_index = faiss.IndexFlatL2(database_features.shape[1])
    faiss_index.add(database_features)
    _, predictions = faiss_index.search(queries_features, max(recall_values))
    # Re-order the globally retrieved candidates using local feature matching.
    return rerank(predictions, q_local_list, r_local_list)