File size: 12,715 Bytes
de15dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#!/usr/bin/env python3
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import torch
import numpy as np
import random
import os
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
import time
import argparse
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip

from util import parallel_apply, get_logger
from simple_dataloaders import SIMPLE_DATALOADER_DICT

# Module-wide logger, assigned by set_seed_logger(). It is bound to None up front so
# the name exists before setup (a module-level ``global`` statement has no effect).
logger = None

def get_args():
    """Parse the evaluation command line.

    Every option is declared once in ``option_specs`` as a ``(flag, kwargs)``
    pair. The pairs are registered in list order, which is also the order shown
    by ``--help``. ``--output_dir`` is the only required option.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    frame_order_help = "Frame order, 0: ordinary order; 1: reverse order; 2: random order."
    option_specs = [
        ("--do_eval", dict(action='store_true', help="Whether to run eval on the dev set.")),
        # input files
        ('--val_csv', dict(type=str, default='data/.val.csv', help='')),
        ('--data_path', dict(type=str, default='data/caption.pickle', help='data pickle file path')),
        ('--features_path', dict(type=str, default='data/videos_feature.pickle', help='feature path')),
        # loading and input sizes
        ('--num_thread_reader', dict(type=int, default=1, help='')),
        ('--batch_size_val', dict(type=int, default=16, help='batch size eval')),
        ('--video_dim', dict(type=int, default=1024, help='video feature dimension')),
        ('--seed', dict(type=int, default=42, help='random seed')),
        ('--max_words', dict(type=int, default=32, help='')),
        ('--max_frames', dict(type=int, default=12, help='')),
        ('--feature_framerate', dict(type=int, default=1, help='')),
        # dataset selection and distributed settings
        ('--datatype', dict(type=str, default='msrvtt', help='data type')),
        ('--world_size', dict(type=int, default=1, help='number of distributed processes')),
        ('--rank', dict(type=int, default=0, help='distributed process rank')),
        ('--local_rank', dict(type=int, default=0, help='distributed process local rank')),
        # encoder depths
        ('--text_num_hidden_layers', dict(type=int, default=12, help="Layer NO. of text.")),
        ('--visual_num_hidden_layers', dict(type=int, default=12, help="Layer NO. of visual.")),
        ('--cross_num_hidden_layers', dict(type=int, default=4, help="Layer NO. of cross.")),
        # retrieval head
        ('--loose_type', dict(action='store_true', help="Default using tight type for retrieval.")),
        ('--expand_msrvtt_sentences', dict(action='store_true', help="")),
        ('--linear_patch', dict(type=str, default="2d", help="linear projection")),
        ('--sim_header', dict(type=str, default="meanP", help="choice a similarity header.")),
        ("--output_dir", dict(default=None, type=str, required=True,
                              help="The output directory where the model predictions and checkpoints will be written.")),
        ("--pretrained_clip_name", dict(default="ViT-B/32", type=str, help="Choose a CLIP version")),
        ('--freeze_layer_num', dict(type=int, default=0, help="Layer NO. of CLIP need to freeze.")),
        ('--slice_framepos', dict(type=int, default=2,
                                  help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")),
        # additional options kept for dataloader compatibility
        ('--train_frame_order', dict(type=int, default=0, choices=[0, 1, 2], help=frame_order_help)),
        ('--eval_frame_order', dict(type=int, default=0, choices=[0, 1, 2], help=frame_order_help)),
        ('--negative_weighting', dict(type=int, default=1, help='Weight the loss for intra negative')),
        ('--n_pair', dict(type=int, default=1, help='Num of pair to output from data loader')),
        # checkpoint selection
        ('--init_model', dict(type=str, default=None, help="Initial model.")),
        ('--resume_model', dict(type=int, default=-1, help="Resume train model from checkpoint.")),
    ]

    parser = argparse.ArgumentParser(description='Simplified CLIP4Clip Evaluation')
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def set_seed_logger(args):
    """Seed the random number generators, create the output dir and the file logger.

    On local rank 0 every parsed option is written to the log.

    Args:
        args: parsed command-line namespace (see ``get_args``).

    Returns:
        The same ``args`` object.
    """
    global logger
    # predefining random initial seeds
    random.seed(args.seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched after this point, not the running one.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # if you are using multi-GPU.
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # exist_ok=True already tolerates an existing directory; no separate check needed.
    os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, "log.txt"))

    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key in sorted(args.__dict__):
            logger.info("  <<< {}: {}".format(key, args.__dict__[key]))

    return args

def init_device(args, local_rank):
    """Select the torch device for this process and record the GPU count.

    Args:
        args: parsed options. ``args.n_gpu`` is set here, and
            ``args.batch_size_val`` is validated against it.
        local_rank: device index used in the returned ``torch.device``.

    Returns:
        ``(device, n_gpu)``. ``n_gpu`` is ``torch.cuda.device_count()``,
        which is 0 on CPU-only hosts.

    Raises:
        ValueError: GPUs are present and ``batch_size_val`` is not a multiple
            of their count.
    """
    global logger

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)

    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    # The divisibility rule only applies when GPUs exist. On a CPU-only host
    # n_gpu is 0 and the modulo would raise ZeroDivisionError.
    if n_gpu > 0 and args.batch_size_val % n_gpu != 0:
        raise ValueError("Invalid batch_size_val and n_gpu parameter: {}%{}, should be == 0".format(
            args.batch_size_val, n_gpu))

    return device, n_gpu

def load_model(args, n_gpu, device, model_file=None):
    """Build a CLIP4Clip model, optionally initialised from a saved state dict.

    If ``model_file`` is empty, the weights file is resolved from ``args`` in
    this order:

    * ``<output_dir>/pytorch_model.bin.<resume_model>`` when ``resume_model >= 0``;
    * otherwise ``args.init_model`` when set;
    * otherwise no file, and the model is built from ``"cross-base"`` alone.

    Args:
        args: parsed options; also forwarded to the model as ``task_config``.
            An optional ``args.cache_dir`` overrides the default cache
            directory on every path.
        n_gpu: unused; kept for call compatibility.
        device: device the model is moved to.
        model_file: explicit state-dict path; takes precedence over ``args``.

    Returns:
        The model on ``device``. Returns ``None`` (and logs an error) when a
        weights file was resolved but does not exist.
    """
    if not model_file:
        if args.resume_model >= 0:
            model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(args.resume_model))
        elif args.init_model:
            model_file = args.init_model

    if not model_file:
        # No checkpoint requested: start from the pretrained "cross-base" weights only.
        model_state_dict = None
    elif os.path.exists(model_file):
        # torch.load unpickles the file; only point this at trusted checkpoints.
        model_state_dict = torch.load(model_file, map_location='cpu')
        if args.local_rank == 0:
            logger.info("Model loaded from %s", model_file)
    else:
        logger.error("Model file not found: %s", model_file)
        return None

    # Same cache-dir resolution for both the pretrained and the state-dict path.
    cache_dir = getattr(args, 'cache_dir', None) or PYTORCH_PRETRAINED_BERT_CACHE
    model = CLIP4Clip.from_pretrained("cross-base",
                                      cache_dir=cache_dir,
                                      state_dict=model_state_dict,
                                      task_config=args)
    model.to(device)
    return model

def _stack_captions_by_video(sim_matrix, cut_off_points, video_num):
    """Regroup a caption-by-video similarity matrix into one padded slab per video.

    Args:
        sim_matrix: 2-D numpy array with one row per caption. The captions of
            each video are assumed to be contiguous and in dataset order.
        cut_off_points: row index of the last caption of each video.
        video_num: number of distinct videos.

    Returns:
        numpy array of shape (video_num, max_captions_per_video, video_num).
        Entry [v, k, :] holds the scores of the k-th caption of video v against
        every video. Unused caption slots are filled with -inf.
    """
    if sim_matrix.shape[1] != video_num:
        # The caching loop encodes one clip per caption row, so the matrix
        # presumably has one column per caption. TODO confirm against
        # get_similarity_logits. Keep one column per video: the copy stored at
        # that video's last caption row.
        sim_matrix = sim_matrix[:, cut_off_points]
    ends = [idx + 1 for idx in cut_off_points]
    starts = [0] + ends[:-1]
    max_length = max(end - start for start, end in zip(starts, ends))
    slabs = []
    for start, end in zip(starts, ends):
        padding = np.full((max_length - (end - start), sim_matrix.shape[1]), -np.inf)
        slabs.append(np.concatenate((sim_matrix[start:end], padding), axis=0))
    return np.stack(slabs, axis=0)


def eval_epoch(args, model, test_dataloader, device, n_gpu):
    """Encode the whole test split, score every caption against every video and log metrics.

    Args:
        args: parsed options; ``args.local_rank`` gates some logging.
        model: model exposing get_sequence_output / get_visual_output /
            get_similarity_logits and a ``loose_type`` attribute. A wrapper
            with ``.module`` is unwrapped.
        test_dataloader: loader whose batches are the 13-tensor tuples unpacked below.
        device: device the model and batches are moved to.
        n_gpu: unused; kept for call compatibility.

    Returns:
        Text-to-video R@1.

    Raises:
        ValueError: ``test_dataloader`` yields no batches.
    """
    if hasattr(model, 'module'):
        model = model.module.to(device)
    else:
        model = model.to(device)

    # multi-sentences retrieval variables
    dataset = test_dataloader.dataset
    multi_sentence_ = bool(getattr(dataset, 'multi_sentence_per_video', False))
    cut_off_points_, sentence_num_, video_num_ = [], -1, -1
    if multi_sentence_:
        # Presumably cumulative caption counts per video; shift to last-caption row indices.
        cut_off_points_ = [itm - 1 for itm in dataset.cut_off_points]
        sentence_num_ = dataset.sentence_num
        video_num_ = dataset.video_num
        logger.info("Eval under the multi-sentence per video clip setting.")
        logger.info("sentence num: {}, video num: {}".format(sentence_num_, video_num_))

    model.eval()
    with torch.no_grad():
        batch_list_t, batch_list_v = [], []
        batch_list_caption, batch_list_video_id = [], []

        total_video_num = 0
        # cache the features
        for batch in test_dataloader:
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, video, video_mask, \
            pairs_masked_text, pairs_token_labels, masked_video, video_labels_index, \
            pairs_input_caption_ids, pairs_decoder_mask, pairs_output_caption_ids, \
            pairs_input_video_id = batch

            batch_list_t.append(model.get_sequence_output(input_ids, segment_ids, input_mask))
            batch_list_v.append(model.get_visual_output(video, video_mask))
            batch_list_caption.append(pairs_input_caption_ids)
            batch_list_video_id.append(pairs_input_video_id)

            total_video_num += video.shape[0]

        if not batch_list_t:
            # With no batches there is no similarity matrix to score.
            raise ValueError("test_dataloader yielded no batches; nothing to evaluate")

        sequence_output = torch.cat(batch_list_t, dim=0)
        visual_output = torch.cat(batch_list_v, dim=0)
        caption_ids = torch.cat(batch_list_caption, dim=0)
        video_ids = torch.cat(batch_list_video_id, dim=0)

        if args.local_rank == 0:
            logger.info("total_video_num: {}".format(total_video_num))

        # calculate the similarity
        sim_matrix = model.get_similarity_logits(sequence_output, visual_output,
                                                 caption_ids, video_ids,
                                                 loose_type=model.loose_type)
        sim_matrix = sim_matrix.cpu().numpy()
        logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))

        if multi_sentence_:
            logger.info("before reshape, sim matrix size: {} x {}".format(sim_matrix.shape[0], sim_matrix.shape[1]))
            sim_tensor = _stack_captions_by_video(sim_matrix, cut_off_points_, video_num_)
            logger.info("after reshape, sim matrix size: {} x {} x {}".format(*sim_tensor.shape))
            # The 3-D (video, caption slot, video) tensor needs the tensor-aware
            # metrics, which presumably mask the -inf padding. Confirm in metrics.py.
            tv_metrics = tensor_text_to_video_metrics(sim_tensor)
            vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_tensor))
        else:
            tv_metrics = compute_metrics(sim_matrix)
            vt_metrics = compute_metrics(sim_matrix.T)

        logger.info("Text-to-Video:")
        logger.info('\t>>>  R@1: {:.1f} - R@5: {:.1f} - R@10: {:.1f} - Median R: {:.1f} - Mean R: {:.1f}'.
                    format(tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR']))
        logger.info("Video-to-Text:")
        logger.info('\t>>>  V2T$R@1: {:.1f} - V2T$R@5: {:.1f} - V2T$R@10: {:.1f} - V2T$Median R: {:.1f} - V2T$Mean R: {:.1f}'.
                    format(vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR']))

    return tv_metrics['R1']

def main():
    """Script entry point: parse options, load the model and test split, then evaluate.

    Evaluation only runs when ``--do_eval`` is given.

    Raises:
        ValueError: ``--datatype`` is unknown, or has neither a test nor a val
            loader registered.
        RuntimeError: ``--do_eval`` was requested but the model could not be loaded.
    """
    args = get_args()
    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    model = load_model(args, n_gpu, device)

    # Dataloader lookup. Explicit errors replace asserts, which are stripped under `python -O`.
    if args.datatype not in SIMPLE_DATALOADER_DICT:
        raise ValueError("Unknown datatype {!r}; available: {}".format(
            args.datatype, list(SIMPLE_DATALOADER_DICT)))
    build_test = SIMPLE_DATALOADER_DICT[args.datatype]["test"]
    build_val = SIMPLE_DATALOADER_DICT[args.datatype]["val"]
    if build_test is None and build_val is None:
        raise ValueError("No test or val dataloader registered for datatype {!r}".format(args.datatype))

    # Use the test split when available, otherwise fall back to val. Only the
    # chosen split is built.
    build_loader = build_test if build_test is not None else build_val
    test_dataloader, test_length = build_loader(args, tokenizer)

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info("  Num examples = %d", test_length)
        logger.info("  Batch size = %d", args.batch_size_val)
        logger.info("  Num steps = %d", len(test_dataloader))

    if args.do_eval:
        if model is None:
            # load_model already logged which file was missing.
            raise RuntimeError("No model to evaluate: load_model() failed (see log above).")
        eval_result = eval_epoch(args, model, test_dataloader, device, n_gpu)
        logger.info("Final R@1: %f", eval_result)
    elif args.local_rank == 0:
        logger.info("--do_eval not set; skipping evaluation.")


if __name__ == "__main__":
    main()