File size: 3,346 Bytes
a37f5d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# -*- coding: UTF-8 -*-
import faiss
import torch
from torch import nn
import logging
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
# from prettytable import PrettyTable
import warnings
import cv2
from os.path import join
# Device string for this module: 'cuda' when a CUDA device is available,
# otherwise 'cpu'.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# import matplotlib.pyplot as plt

def match_batch_tensor(fm1, fm2, trainflag, grid_size, T2=0.7):
    '''
    Score candidates by mutual-nearest-neighbour matching of local features.

    fm1: (l, D) query local features, e.g. (529, 768).
    fm2: (N, l, D) local features of N candidates, e.g. (100, 529, 768).
    trainflag: if True, return a similarity for the FIRST candidate only
        (local_sim calls it with N == 1 in this mode); otherwise return one
        score per candidate.
    grid_size: unused; kept for interface compatibility.
    T2: minimum dot product for a mutual match to be kept. This is a cosine
        similarity only if the features are L2-normalised (assumed, not
        checked here).

    Returns:
        trainflag=True: 0-d tensor with the mean dot product over the mutual
            matches that pass T2. If none pass, falls back to the mean of the
            position-wise dot products fm1[k] . fm2[0][k] (requires both maps
            to have the same number of positions).
        trainflag=False: float tensor of shape (N,) holding the number of
            mutual matches per candidate that pass T2.
    '''
    # M[n, j, k] = fm2[n, j] . fm1[k]
    M = torch.matmul(fm2, fm1.T)

    # max1[n, k]: best candidate position for query position k.
    max1 = torch.argmax(M, dim=1)
    # max2[n, j]: best query position for candidate position j.
    max2 = torch.argmax(M, dim=2)
    # Follow query -> candidate -> query; positions that map back to
    # themselves are mutual nearest neighbours.
    back = torch.gather(max2, 1, max1)
    query_pos = torch.arange(M.shape[-1], device=M.device)
    valid = back == query_pos.unsqueeze(0)
    # Allocate on the inputs' device rather than assuming CUDA.
    scores = torch.zeros(fm2.shape[0], device=M.device)

    for i in range(fm2.shape[0]):
        # Always 1-D (possibly empty). The former .squeeze() made a single
        # match 0-d, which was then treated as "no match".
        idx1 = torch.nonzero(valid[i], as_tuple=True)[0]
        idx2 = max1[i][idx1]
        assert idx1.shape == idx2.shape

        # Dot product of each mutual pair, keeping only those above T2.
        pair_sim = torch.sum(fm1[idx1] * fm2[i][idx2], dim=1)
        pair_sim = pair_sim[pair_sim > T2]

        if trainflag:
            if pair_sim.numel() > 0:
                return torch.mean(pair_sim, dim=0)
            # No surviving pair: avoid the NaN from averaging an empty tensor.
            print("No mutual nearest neighbors!")
            return torch.mean(torch.sum(fm1 * fm2[i], dim=1), dim=0)

        scores[i] = pair_sim.numel()
    return scores

def local_sim(features_1, features_2, trainflag=False):
    '''
    Local-feature similarity between queries and candidates via match_batch_tensor.

    features_2 must be 3-D; its first dimension B is the number of pairs
    (train) or of candidates (eval).

    trainflag=True: features_1[i] is matched against features_2[i] for each
        i < B. Returns a (B,) tensor of match_batch_tensor train-mode
        similarities.
    trainflag=False: features_1 is a single query feature map (presumably
        (l, D), as in rerank) scored against all B candidates. Returns
        match_batch_tensor's (B,) per-candidate match counts.
    '''
    # Unpacking keeps the original requirement that features_2 be 3-D.
    B, _, _ = features_2.shape
    if not trainflag:
        return match_batch_tensor(features_1, features_2, trainflag, grid_size=(61, 61))

    # Allocate on the inputs' device instead of hard-coding CUDA.
    similarity = torch.zeros(B, device=features_2.device)
    for i in range(B):
        # One candidate per call: in train mode match_batch_tensor only
        # scores fm2[0].
        similarity[i] = match_batch_tensor(
            features_1[i], features_2[i].unsqueeze(0), trainflag, grid_size=(61, 61)
        )
    return similarity
    
def rerank(predictions, queries_local_features, database_local_features):
    '''
    Reorder each query's retrieved candidates by local-feature match score.

    predictions: (Q, K) array of database indices per query (e.g. the index
        array returned by a faiss search).
    queries_local_features: item q holds the local features of query q.
    database_local_features: indexable by an index array; row `pred` gives
        the local features of those candidates.

    Returns a numpy array with the same indices as `predictions`. Each row is
    sorted by descending local_sim score. Candidates with equal scores keep
    their incoming (global-retrieval) order.
    '''
    reranked = []
    print("reranking...")
    for query_index, pred in enumerate(tqdm(predictions)):
        # Place tensors on the module-level device instead of assuming CUDA.
        query_local_features = torch.as_tensor(queries_local_features[query_index], device=device)
        positives_local_features = torch.as_tensor(database_local_features[pred], device=device)
        scores = local_sim(query_local_features, positives_local_features, trainflag=False)
        # Match counts are integers, so ties are common. A stable sort on the
        # negated scores keeps tied candidates in their original ranking,
        # whereas argsort()[::-1] reversed them.
        order = np.argsort(-scores.cpu().numpy(), kind='stable')
        reranked.append(pred[order])
    return np.array(reranked)
    
def run_rerank(queries_features, database_features, q_local_list, r_local_list, recall_values=(1,)):
    '''
    Retrieve candidates with an exact L2 faiss search, then rerank them with local features.

    queries_features: (Q, D) global descriptors of the queries.
    database_features: (R, D) global descriptors of the database.
    q_local_list / r_local_list: local features of queries / database items,
        passed through to rerank.
    recall_values: the search retrieves max(recall_values) candidates per query.

    Returns the (Q, max(recall_values)) array of reranked database indices.
    '''
    # faiss requires C-contiguous float32 input.
    database_features = np.ascontiguousarray(database_features, dtype=np.float32)
    queries_features = np.ascontiguousarray(queries_features, dtype=np.float32)

    # Infer the descriptor size instead of hard-coding 8448.
    faiss_index = faiss.IndexFlatL2(database_features.shape[1])
    faiss_index.add(database_features)

    _, predictions = faiss_index.search(queries_features, max(recall_values))

    return rerank(predictions, q_local_list, r_local_list)