# Source file: test/cache_main_task_retrieval.py
# Uploaded by jaewooo ("Initial upload", commit de15dc5, verified).
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import torch
import torch.nn.functional as F
import numpy as np
import random
import os
from metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
import time
import argparse
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modules.modeling import CLIP4Clip
import matplotlib.pyplot as plt
from modules.optimization import BertAdam
from util import parallel_apply, get_logger
from modules.until_module import AllGather
from dataloaders.data_dataloaders import DATALOADER_DICT
# torch.distributed.init_process_group(backend="nccl")
global logger
def get_args(description='CLIP4Clip on Retrieval Task'):
    """Parse and validate the command-line options of the retrieval script.

    Besides plain parsing this function:

    * forces ``loose_type`` to ``False`` when ``--sim_header tightTransf`` is used;
    * divides ``batch_size`` by ``gradient_accumulation_steps``;
    * overrides ``local_rank`` / ``rank`` / ``world_size`` with the integer values
      of the ``LOCAL_RANK`` / ``RANK`` / ``WORLD_SIZE`` environment variables when
      those are set (malformed values are ignored).

    Args:
        description: Text shown at the top of ``--help``.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: If ``gradient_accumulation_steps < 1`` or neither
            ``--do_train`` nor ``--do_eval`` is given.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("--do_pretrain", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")

    parser.add_argument('--train_csv', type=str, default='data/.train.csv', help='')
    parser.add_argument('--val_csv', type=str, default='data/.val.csv', help='')
    parser.add_argument('--data_path', type=str, default='data/caption.pickle', help='data pickle file path')
    parser.add_argument('--features_path', type=str, default='data/videos_feature.pickle', help='feature path')

    parser.add_argument('--num_thread_reader', type=int, default=1, help='')
    parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
    parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=3500, help='batch size eval')
    parser.add_argument('--lr_decay', type=float, default=0.9, help='Learning rate exp epoch decay')
    parser.add_argument('--n_display', type=int, default=100, help='Information display frequence')
    parser.add_argument('--video_dim', type=int, default=1024, help='video feature dimension')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--max_words', type=int, default=20, help='')
    parser.add_argument('--max_frames', type=int, default=100, help='')
    parser.add_argument('--feature_framerate', type=int, default=1, help='')
    parser.add_argument('--margin', type=float, default=0.1, help='margin for loss')
    parser.add_argument('--hard_negative_rate', type=float, default=0.5, help='rate of intra negative sample')
    parser.add_argument('--negative_weighting', type=int, default=1, help='Weight the loss for intra negative')
    parser.add_argument('--n_pair', type=int, default=1, help='Num of pair to output from data loader')

    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument("--cross_model", default="cross-base", type=str, required=False, help="Cross module")
    parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
    parser.add_argument("--resume_model", default=None, type=str, required=False, help="Resume train model.")
    parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--n_gpu', type=int, default=1, help="Changed in the execute process.")

    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")

    parser.add_argument("--task_type", default="retrieval", type=str, help="Point the task `retrieval` to finetune.")
    parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")

    parser.add_argument("--world_size", default=0, type=int, help="distribted training")
    parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
    # alias for torch.distributed.run / launch passing --local-rank
    parser.add_argument("--local-rank", dest="local_rank", default=0, type=int, help="alias for local_rank")
    parser.add_argument("--rank", default=0, type=int, help="distribted training")
    parser.add_argument('--coef_lr', type=float, default=1., help='coefficient for bert branch.')
    parser.add_argument('--use_mil', action='store_true', help="Whether use MIL as Miech et. al. (2020).")
    parser.add_argument('--sampled_use_mil', action='store_true', help="Whether MIL, has a high priority than use_mil.")

    parser.add_argument('--text_num_hidden_layers', type=int, default=12, help="Layer NO. of text.")
    parser.add_argument('--visual_num_hidden_layers', type=int, default=12, help="Layer NO. of visual.")
    parser.add_argument('--cross_num_hidden_layers', type=int, default=4, help="Layer NO. of cross.")

    parser.add_argument('--loose_type', action='store_true', help="Default using tight type for retrieval.")
    parser.add_argument('--expand_msrvtt_sentences', action='store_true', help="")

    parser.add_argument('--train_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")
    parser.add_argument('--eval_frame_order', type=int, default=0, choices=[0, 1, 2],
                        help="Frame order, 0: ordinary order; 1: reverse order; 2: random order.")

    parser.add_argument('--freeze_layer_num', type=int, default=0, help="Layer NO. of CLIP need to freeze.")
    parser.add_argument('--slice_framepos', type=int, default=0, choices=[0, 1, 2],
                        help="0: cut from head frames; 1: cut from tail frames; 2: extract frames uniformly.")
    parser.add_argument('--linear_patch', type=str, default="2d", choices=["2d", "3d"],
                        help="linear projection of flattened patches.")
    parser.add_argument('--sim_header', type=str, default="meanP",
                        choices=["meanP", "seqLSTM", "seqTransf", "tightTransf"],
                        help="choice a similarity header.")

    parser.add_argument("--pretrained_clip_name", default="ViT-B/32", type=str, help="Choose a CLIP version")
    parser.add_argument("--use_rff", action='store_true', help="Use RFF hypervector encoding for video embeddings")
    parser.add_argument("--rff_dim", type=int, default=3000, help="Hypervector dimension for RFF encoding")
    parser.add_argument("--use_clip4hashing", action="store_true", help="CLIP4Hashing 손실·해시 경로 사용 여부")
    # The help text previously advertised a default of 1024 while the real default is 2048.
    parser.add_argument("--hash_bit", type=int, default=2048, help="해시 코드 비트 수 (default 2048)")

    # Projection options
    parser.add_argument('--proj', type=int, default=0, help='Projection dim (0 to disable, e.g., 3008)')
    parser.add_argument('--proj_act', type=str, default='tanh', choices=['tanh', 'relu', 'gelu', 'sigmoid'],
                        help='Activation after projection')
    parser.add_argument('--binary_eval', action='store_true', help='Use binarized retrieval at eval (sign + sum)')

    args = parser.parse_args()

    if args.sim_header == "tightTransf":
        args.loose_type = False

    # Check parameters
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))
    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    args.batch_size = int(args.batch_size / args.gradient_accumulation_steps)

    # Accept env fallback if provided by torchrun. A value that is not an
    # integer is ignored on purpose so the command-line value stays in effect.
    env_overrides = (
        ("LOCAL_RANK", "local_rank"),
        ("RANK", "rank"),
        ("WORLD_SIZE", "world_size"),
    )
    for env_name, attr_name in env_overrides:
        raw_value = os.environ.get(env_name)
        if raw_value is None:
            continue
        try:
            setattr(args, attr_name, int(raw_value))
        except ValueError:
            pass

    return args
def set_seed_logger(args):
    """Seed all RNGs, record the distributed topology on ``args`` and create the logger.

    Reads the world size and rank from ``torch.distributed`` (so the process
    group must already be initialised), binds the process to
    ``args.local_rank``, creates ``args.output_dir`` when missing and installs
    the module-level ``logger`` writing to ``<output_dir>/log.txt``.

    Args:
        args: Parsed options; ``world_size`` and ``rank`` are overwritten.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: If ``binary_eval`` is enabled with a positive ``proj`` that
            is not a multiple of 64.
    """
    global logger

    # Fix every random source to the same seed.
    seed = args.seed
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    world_size = torch.distributed.get_world_size()
    torch.cuda.set_device(args.local_rank)
    args.world_size = world_size
    args.rank = torch.distributed.get_rank()

    out_dir = args.output_dir
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

    logger = get_logger(os.path.join(out_dir, "log.txt"))

    if args.local_rank == 0:
        logger.info("Effective parameters:")
        for key, value in sorted(vars(args).items()):
            logger.info(" <<< {}: {}".format(key, value))

    # Sanity check for binary eval + bit packing compatibility
    binary_with_proj = getattr(args, 'binary_eval', False) and getattr(args, 'proj', 0) > 0
    if binary_with_proj and args.proj % 64 != 0:
        raise ValueError(f"--proj must be divisible by 64 for binary eval, got {args.proj}")

    return args
def init_device(args, local_rank):
    """Select the compute device for this process and validate the batch sizes.

    Args:
        args: Parsed options; ``args.n_gpu`` is overwritten with the number of
            visible CUDA devices.
        local_rank: Device index used when building the ``torch.device``.

    Returns:
        Tuple ``(device, n_gpu)``.

    Raises:
        ValueError: If at least one GPU is visible and ``batch_size`` or
            ``batch_size_val`` is not divisible by the GPU count.
    """
    global logger

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu", local_rank)

    n_gpu = torch.cuda.device_count()
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    args.n_gpu = n_gpu

    # With no visible GPU the modulo below would raise ZeroDivisionError even
    # though a CPU device was selected above, so only check when n_gpu > 0.
    if n_gpu > 0 and (args.batch_size % n_gpu != 0 or args.batch_size_val % n_gpu != 0):
        raise ValueError("Invalid batch_size/batch_size_val and n_gpu parameter: {}%{} and {}%{}, should be == 0".format(
            args.batch_size, args.n_gpu, args.batch_size_val, args.n_gpu))

    return device, n_gpu
def init_model(args, device, n_gpu, local_rank):
    """Build the CLIP4Clip model, optionally attach a projection head, and move it to ``device``.

    When ``args.init_model`` is set, its state dict is loaded on CPU and passed
    to ``CLIP4Clip.from_pretrained``. When ``args.proj > 0`` a bias-free
    ``Linear(512, args.proj)`` plus an activation are attached as
    ``model.proj_head`` / ``model.proj_activation`` and the state dict is
    re-applied non-strictly so saved projection weights (if any) are picked up.

    Returns:
        The model placed on ``device`` (not yet wrapped in DDP).
    """
    model_state_dict = torch.load(args.init_model, map_location='cpu') if args.init_model else None

    # Prepare model
    if args.cache_dir:
        cache_dir = args.cache_dir
    else:
        cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(
        args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    # Attach projection head if requested (before DDP wrapping)
    proj_dim = getattr(args, 'proj', 0)
    if proj_dim and proj_dim > 0:
        print("Projection")

        head = torch.nn.Linear(512, proj_dim, bias=False)
        torch.nn.init.normal_(head.weight, mean=0.0, std=0.02)

        # Unknown activation names fall back to tanh.
        activation_types = {
            'tanh': torch.nn.Tanh,
            'relu': torch.nn.ReLU,
            'gelu': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
        }
        model.proj_head = head
        model.proj_activation = activation_types.get(args.proj_act, torch.nn.Tanh)()

        # If init_model contains proj_head params, load them now. This is
        # best-effort: any failure leaves the freshly initialised head in place.
        if model_state_dict is not None:
            try:
                model.load_state_dict(model_state_dict, strict=False)
            except Exception:
                pass

    model.to(device)
    return model
def prep_optimizer(args, model, num_train_optimization_steps, device, n_gpu, local_rank, coef_lr=1.):
    """Create the BertAdam optimizer and wrap the model in DistributedDataParallel.

    Parameters are split into four groups along two axes: whether weight decay
    applies (names containing ``bias`` / ``LayerNorm.bias`` / ``LayerNorm.weight``
    are exempt) and whether the name contains ``"clip."`` (those groups use
    ``args.lr * coef_lr``; the others use the optimizer default ``args.lr``).

    Returns:
        Tuple ``(optimizer, scheduler, ddp_model)``; ``scheduler`` is always ``None``.
    """
    if hasattr(model, 'module'):
        model = model.module

    exempt_markers = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    weight_decay = 0.2

    # One pass over the parameters, preserving their order inside each bucket.
    decay_clip, decay_other, plain_clip, plain_other = [], [], [], []
    for name, param in list(model.named_parameters()):
        is_exempt = any(marker in name for marker in exempt_markers)
        is_clip = "clip." in name
        if is_exempt:
            bucket = plain_clip if is_clip else plain_other
        else:
            bucket = decay_clip if is_clip else decay_other
        bucket.append(param)

    clip_lr = args.lr * coef_lr
    optimizer_grouped_parameters = [
        {'params': decay_clip, 'weight_decay': weight_decay, 'lr': clip_lr},
        {'params': decay_other, 'weight_decay': weight_decay},
        {'params': plain_clip, 'weight_decay': 0.0, 'lr': clip_lr},
        {'params': plain_other, 'weight_decay': 0.0},
    ]

    scheduler = None
    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=args.warmup_proportion,
                         schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
                         t_total=num_train_optimization_steps, weight_decay=weight_decay,
                         max_grad_norm=1.0)

    # Right after building the optimizer, dump each group's lr, size and a
    # sample of up to eight parameter names.
    trainable = {n: p for n, p in model.named_parameters() if p.requires_grad}
    name_by_id = {id(p): n for n, p in trainable.items()}
    for group_idx, group in enumerate(optimizer.param_groups):
        print(f"[group {group_idx}] lr={group['lr']:.2e}, params={len(group['params'])}")
        for param in group["params"][:8]:
            print(" ", name_by_id.get(id(param), "?"))

    # Quick debug: log param tensor count per rank so mismatched models are
    # visible before the DDP wrap. Failures here are deliberately ignored.
    try:
        num_tensors = len(list(model.parameters()))
        print(f"[DDP-DEBUG] rank={args.rank} local_rank={local_rank} param_tensors={num_tensors}")
    except Exception:
        pass

    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

    return optimizer, scheduler, model
def save_model(epoch, args, model, optimizer, tr_loss, type_name=""):
    """Write the model weights and the optimizer state for ``epoch`` into ``args.output_dir``.

    Files are named ``pytorch_model.bin.[<type_name>.]<epoch>`` and
    ``pytorch_opt.bin.[<type_name>.]<epoch>``. The optimizer file also stores
    the epoch number and ``tr_loss``.

    Returns:
        Path of the saved model file.
    """
    # Only save the model itself, never the parallel wrapper.
    bare_model = getattr(model, 'module', model)

    tag = "" if type_name == "" else type_name + "."
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}{}".format(tag, epoch))
    optimizer_state_file = os.path.join(args.output_dir, "pytorch_opt.bin.{}{}".format(tag, epoch))

    torch.save(bare_model.state_dict(), output_model_file)

    optimizer_payload = {
        'epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': tr_loss,
    }
    torch.save(optimizer_payload, optimizer_state_file)

    logger.info("Model saved to %s", output_model_file)
    logger.info("Optimizer saved to %s", optimizer_state_file)
    return output_model_file
def load_model(epoch, args, n_gpu, device, model_file=None):
    """Rebuild a CLIP4Clip model from a saved checkpoint.

    Args:
        epoch: Epoch number used to derive the default checkpoint name.
        args: Parsed options (``output_dir``, ``cache_dir``, ``cross_model``,
            ``proj``, ``proj_act``, ``local_rank`` are read).
        n_gpu: Unused here; kept for interface compatibility.
        device: Device the model is moved to.
        model_file: Explicit checkpoint path; when empty or ``None`` the file
            ``<output_dir>/pytorch_model.bin.<epoch>`` is used.

    Returns:
        The loaded model on ``device``, or ``None`` when the checkpoint file
        does not exist.
    """
    if not model_file:
        model_file = os.path.join(args.output_dir, "pytorch_model.bin.{}".format(epoch))

    # Previously the missing-file path fell through to a log statement that
    # referenced ``cache_dir`` before assignment (NameError). Return early instead.
    if not os.path.exists(model_file):
        logger.warning("Model file not found, nothing loaded: %s", model_file)
        return None

    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    model_state_dict = torch.load(model_file, map_location='cpu')
    if args.local_rank == 0:
        logger.info("Model loaded from %s", model_file)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed')
    model = CLIP4Clip.from_pretrained(args.cross_model, cache_dir=cache_dir, state_dict=model_state_dict, task_config=args)

    # Attach projection head if needed and load any matching weights
    proj_dim = getattr(args, 'proj', 0)
    if proj_dim and proj_dim > 0:
        proj = torch.nn.Linear(512, proj_dim, bias=False)
        torch.nn.init.normal_(proj.weight, mean=0.0, std=0.02)

        # Unknown activation names fall back to tanh.
        activation_types = {
            'tanh': torch.nn.Tanh,
            'relu': torch.nn.ReLU,
            'gelu': torch.nn.GELU,
            'sigmoid': torch.nn.Sigmoid,
        }
        model.proj_head = proj
        model.proj_activation = activation_types.get(args.proj_act, torch.nn.Tanh)()

        # Best-effort reload so proj_head weights stored in the checkpoint are
        # used. A failure keeps the randomly initialised head, which is worth
        # a visible warning rather than a silent pass.
        try:
            model.load_state_dict(model_state_dict, strict=False)
        except Exception as exc:
            logger.warning("Could not load projection-head weights from %s: %s", model_file, exc)

    model.to(device)
    logger.info(f"모델을 로드합니다:{cache_dir}")
    return model
def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, scheduler, global_step, local_rank=0):
    """Run one pass over ``train_dataloader`` and update the model.

    Two loss paths exist:

    * projection path (``args.proj > 0`` and the unwrapped model has a
      ``proj_head``): text/video features are all-gathered, projected,
      activated, L2-normalised and scored with a symmetric loss;
    * default path: ``model(...)`` returns the loss directly.

    Args:
        epoch: Zero-based epoch index (only used for logging).
        args: Parsed options (``n_display``, ``gradient_accumulation_steps``,
            ``epochs``, ``proj`` are read).
        model: The (possibly wrapped) model.
        train_dataloader: Iterable of 5-tuples
            ``(input_ids, input_mask, segment_ids, video, video_mask)``.
        device: Target device for the batch when ``n_gpu == 1``.
        n_gpu: Number of GPUs reported by the caller.
        optimizer: Optimizer exposing ``step``, ``zero_grad`` and ``get_lr``.
        scheduler: Optional LR scheduler; stepped before the optimizer when set.
        global_step: Running optimizer-step counter.
        local_rank: Only rank 0 writes progress logs.

    Returns:
        Tuple ``(total_loss, global_step)`` where ``total_loss`` is the sum of
        the per-step (accumulation-scaled) losses divided by the number of
        batches.
    """
    global logger
    torch.cuda.empty_cache()
    # Unwrapped module: needed for direct access to encoders, proj head and logit_scale.
    net = model.module if hasattr(model, 'module') else model
    net.train()
    log_step = args.n_display
    start_time = time.time()
    total_loss = 0

    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            # multi-gpu does scattering it-self
            batch = tuple(t.to(device=device, non_blocking=True) for t in batch)

        input_ids, input_mask, segment_ids, video, video_mask = batch

        # If projection head enabled, override loss path to use proj+tanh->sum embeddings
        # NOTE(review): this branch calls `net` (the unwrapped module) instead of
        # `model`; when n_gpu > 1 the batch is not moved to `device` above, and the
        # DDP wrapper's forward is bypassed -- confirm device placement and
        # gradient synchronisation for multi-GPU runs.
        if getattr(args, 'proj', 0) and args.proj > 0 and hasattr(net, 'proj_head'):
            # Forward backbone encoders only
            sequence_output, visual_output = net.get_sequence_visual_output(
                input_ids, segment_ids, input_mask, video, video_mask)

            # Gather across processes for global negatives
            sequence_output = AllGather.apply(sequence_output, args)
            visual_output = AllGather.apply(visual_output, args)
            video_mask_g = AllGather.apply(video_mask.view(-1, video_mask.shape[-1]), args)

            # Text: [B,1,512] -> [B,512] -> proj -> act -> L2
            txt = sequence_output.squeeze(1)
            txt = net.proj_activation(net.proj_head(txt))
            txt = F.normalize(txt, dim=-1)

            # Video: [B,T,512] -> proj -> act -> mask -> sum(T) -> L2
            vid = net.proj_activation(net.proj_head(visual_output))
            vm = video_mask_g.to(dtype=vid.dtype).unsqueeze(-1)
            vid = (vid * vm).sum(dim=1)
            vid = F.normalize(vid, dim=-1)

            logit_scale = net.clip.logit_scale.exp()
            sim = logit_scale * torch.matmul(txt, vid.t())

            # Same symmetric CE loss as original
            loss = net.loss_fct(sim) + net.loss_fct(sim.T)
            loss = loss * 0.5
        else:
            loss = model(input_ids, segment_ids, input_mask, video, video_mask)

        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if args.gradient_accumulation_steps > 1:
            # Scale so accumulated gradients match one large-batch step.
            loss = loss / args.gradient_accumulation_steps

        loss.backward()

        total_loss += float(loss)
        # Only update weights once every `gradient_accumulation_steps` batches.
        if (step + 1) % args.gradient_accumulation_steps == 0:

            torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)

            if scheduler is not None:
                scheduler.step()  # Update learning rate schedule

            optimizer.step()
            optimizer.zero_grad()

            # Cap the learned temperature at log(100).
            # https://github.com/openai/CLIP/issues/46
            torch.clamp_(net.clip.logit_scale.data, max=np.log(100))

            global_step += 1
            if global_step % log_step == 0 and local_rank == 0:
                logger.info("Epoch: %d/%s, Step: %d/%d, Lr: %s, Loss: %f, Time/step: %f", epoch + 1,
                            args.epochs, step + 1,
                            len(train_dataloader), "-".join([str('%.9f'%itm) for itm in sorted(list(set(optimizer.get_lr())))]),
                            float(loss),
                            (time.time() - start_time) / (log_step * args.gradient_accumulation_steps))
                # Restart the timer for the next logging window.
                start_time = time.time()

    total_loss = total_loss / len(train_dataloader)
    return total_loss, global_step
def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
sim_matrix = []
for idx1, b1 in enumerate(batch_list_t):
input_mask, segment_ids, *_tmp = b1
sequence_output = batch_sequence_output_list[idx1]
each_row = []
for idx2, b2 in enumerate(batch_list_v):
video_mask, *_tmp = b2
visual_output = batch_visual_output_list[idx2]
b1b2_logits, *_tmp = model.get_similarity_logits(sequence_output, visual_output, input_mask, video_mask,
loose_type=model.loose_type)
b1b2_logits = b1b2_logits.cpu().detach().numpy()
each_row.append(b1b2_logits)
each_row = np.concatenate(tuple(each_row), axis=-1)
sim_matrix.append(each_row)
return sim_matrix
def eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer):
    """Evaluate text<->video retrieval, with on-disk feature caching and a Top-10 dump.

    Steps:
      (A) load cached features from ``args.output_dir`` or extract and cache them;
      (B) build the text x video similarity matrix (optionally through the
          projection head, optionally binarised);
      (C) compute retrieval metrics (multi-sentence aware);
      (D) decode the query texts and, on local rank 0, write the per-query
          Top-10 videos to a TSV and a JSON file.

    Fixes over the previous revision:
      * query token ids are now actually collected during extraction (the list
        used to stay empty, so the Top-10 dump was empty and an empty list was
        cached);
      * multi-sentence cut-off points are read from the dataset up front, so a
        run that loads features from cache no longer fails on ``max([])``.

    Args:
        args: Parsed options.
        model: Model (possibly wrapped; the inner ``module`` is used).
        test_dataloader: Yields ``(input_ids, input_mask, segment_ids, video, video_mask)``.
        device: Device used for extraction and scoring.
        n_gpu: Number of GPUs; ``> 1`` shards the text batches over devices.
        tokenizer: Tokenizer exposing ``encoder``, ``decoder`` and ``decode``.

    Returns:
        Text-to-video R@1 (``tv_metrics['R1']``).
    """
    import json

    def _decode_query(tokenizer, ids_tensor):
        """Turn a 1-D sequence of token ids back into text; never raises."""
        try:
            if isinstance(ids_tensor, torch.Tensor):
                ids = ids_tensor.cpu().numpy().tolist()
            else:
                ids = ids_tensor.tolist() if hasattr(ids_tensor, 'tolist') else list(ids_tensor)

            # Special token ids of the CLIP tokenizer (with fallback defaults).
            start_token_id = tokenizer.encoder.get('<|startoftext|>', 49406)
            end_token_id = tokenizer.encoder.get('<|endoftext|>', 49407)

            # Drop padding (0) and the start/end markers.
            clean_ids = [
                token_id for token_id in ids
                if token_id > 0 and token_id != start_token_id and token_id != end_token_id
            ]
            if not clean_ids:
                return "<empty_query>"

            # Keep only ids inside the vocabulary range.
            vocab_size = len(tokenizer.decoder)
            valid_ids = [tid for tid in clean_ids if tid < vocab_size]
            if not valid_ids:
                return "<invalid_tokens>"

            try:
                return tokenizer.decode(valid_ids).strip()
            except KeyError:
                # Fall back to token-by-token decoding.
                decoded_tokens = [
                    tokenizer.decoder[tid] if tid in tokenizer.decoder else f"<unk_{tid}>"
                    for tid in valid_ids
                ]
                # BPE post-processing: the word-end marker becomes a space.
                text = ''.join(decoded_tokens).replace('</w>', ' ').strip()
                return text if text else "<decode_partial_error>"
            except Exception as e:
                return f"<decode_error: {str(e)[:50]}>"
        except Exception as e:
            return f"<general_error: {str(e)[:50]}>"

    def _get_video_ids_from_dataset(dataset, num_videos):
        """Return video ids from the first matching dataset attribute, else "0".."N-1"."""
        for attr in ["video_list", "video_ids", "video_names", "videos", "vids", "id_list"]:
            if hasattr(dataset, attr):
                obj = getattr(dataset, attr)
                try:
                    if isinstance(obj, (list, tuple)) and len(obj) == num_videos:
                        return list(map(str, obj))
                except Exception:
                    pass
        return [str(i) for i in range(num_videos)]

    def _topk_indices(scores, k):
        """Indices of the ``k`` highest scores, best first (argpartition, then sort)."""
        idxs = np.argpartition(-scores, kth=min(k, len(scores) - 1))[:k]
        return idxs[np.argsort(-scores[idxs])]

    global logger

    dataset = test_dataloader.dataset

    # Multi-sentence metadata comes from the dataset, not from the cache, so it
    # must be read regardless of whether features are loaded or extracted.
    multi_sentence_ = bool(getattr(dataset, 'multi_sentence_per_video', False))
    cut_off_points_, sentence_num_, video_num_ = [], -1, -1
    if multi_sentence_:
        cut_off_points_ = [itm - 1 for itm in dataset.cut_off_points]
        sentence_num_ = dataset.sentence_num
        video_num_ = dataset.video_num
        logger.warning("Eval under multi-sentence-per-video. sentence num: %s, video num: %s",
                       sentence_num_, video_num_)

    model = model.module if hasattr(model, 'module') else model
    model = model.to(device)
    model.eval()
    logger.info("Model %s", "training" if model.training else "eval")

    # suffix for cache/result naming
    suffix = "_hash" if getattr(args, "use_clip4hashing", False) else ""
    suffix += "_rff" if args.use_rff else ""
    suffix += f"_proj{args.proj}" if getattr(args, 'proj', 0) and args.proj > 0 else ""
    suffix += "_binary" if getattr(args, 'binary_eval', False) else ""
    suffix += "_trained" if args.init_model else ""

    # (A) load or build the feature cache
    if "train" in args.val_csv and "10k" in args.val_csv:
        cache_name = f"{args.datatype}_train_test_10k_cache{suffix}.pt"
        logger.info(f"9k 훈련 데이터 캐시 생성: {cache_name}")
    else:
        cache_name = f"{args.datatype}_eval_cache{suffix}.pt"
        logger.info(f"평가 데이터 캐시: {cache_name}")

    cache_path = os.path.join(args.output_dir, cache_name)

    if os.path.exists(cache_path):
        logger.info(f"캐시된 피처를 로드합니다: {cache_path}")
        cache = torch.load(cache_path, map_location=device)
        batch_sequence_output_list = cache['batch_sequence_output_list']
        batch_visual_output_list = cache['batch_visual_output_list']
        batch_list_t = cache['batch_list_t']
        batch_list_v = cache['batch_list_v']
        text_input_ids_list = cache.get('text_input_ids_list', None)
        video_ids = cache.get('video_ids', None)
    else:
        print("Caching feature..")
        batch_list_t = []
        batch_list_v = []
        batch_sequence_output_list, batch_visual_output_list = [], []
        text_input_ids_list = []
        total_video_num = 0

        with torch.no_grad():
            for bid, batch in enumerate(test_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, video, video_mask = batch

                # Keep the token ids on CPU for later query decoding.
                ids_cpu = input_ids.detach().cpu()
                text_input_ids_list.append(ids_cpu)

                if multi_sentence_:
                    b, *_t = video.shape
                    sequence_output = model.get_sequence_output(input_ids, segment_ids, input_mask)
                    batch_sequence_output_list.append(sequence_output)
                    # The scoring helper ignores the third tuple member.
                    batch_list_t.append((input_mask, segment_ids, ids_cpu))

                    # Encode each video once: only rows at a cut-off point are kept.
                    s_, e_ = total_video_num, total_video_num + b
                    filter_inds = [itm - s_ for itm in cut_off_points_ if s_ <= itm < e_]
                    if len(filter_inds) > 0:
                        video, video_mask = video[filter_inds, ...], video_mask[filter_inds, ...]
                        visual_output = model.get_visual_output(video, video_mask)
                        batch_visual_output_list.append(visual_output)
                        batch_list_v.append((video_mask,))
                    total_video_num += b
                else:
                    sequence_output, visual_output = model.get_sequence_visual_output(
                        input_ids, segment_ids, input_mask, video, video_mask)
                    batch_sequence_output_list.append(sequence_output)
                    batch_list_t.append((input_mask, segment_ids, ids_cpu))
                    batch_visual_output_list.append(visual_output)
                    batch_list_v.append((video_mask,))

                print("{}/{}\r".format(bid, len(test_dataloader)), end="")

        # Video id list (falls back to "0".."N-1" when the dataset exposes none).
        num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
        video_ids = _get_video_ids_from_dataset(dataset, num_videos)

        logger.info(f"추출된 피처를 캐시에 저장합니다: {cache_path}")
        torch.save({
            'batch_sequence_output_list': batch_sequence_output_list,
            'batch_visual_output_list': batch_visual_output_list,
            'batch_list_t': batch_list_t,
            'batch_list_v': batch_list_v,
            'text_input_ids_list': text_input_ids_list,
            'video_ids': video_ids,
        }, cache_path)

    logger.info(f"[Cache] 텍스트 피쳐 개수={len(batch_sequence_output_list)} "
                f"각 텐서 shape={batch_sequence_output_list[0].shape}")
    logger.info(f"[Cache] 비디오 피쳐 개수={len(batch_visual_output_list)} "
                f"각 텐서 shape={batch_visual_output_list[0].shape}")

    # Older caches may lack the token ids (missing key, None, or an empty list
    # written by the previous revision). Recover them from batch_list_t when
    # it carries them, otherwise re-scan the dataloader.
    if not text_input_ids_list:
        if batch_list_t and all(len(item) >= 3 for item in batch_list_t):
            logger.info("batch_list_t에서 input_ids를 추출합니다.")
            text_input_ids_list = [item[2] for item in batch_list_t]
        else:
            logger.info("캐시에 text_input_ids_list가 없어 재수집합니다(호환성 경로).")
            text_input_ids_list = [batch[0].detach().cpu() for batch in test_dataloader]

    # Older caches may lack video ids as well.
    if video_ids is None:
        num_videos = int(sum(bv.shape[0] for bv in batch_visual_output_list))
        video_ids = _get_video_ids_from_dataset(dataset, num_videos)

    # (B) similarity matrix
    def _run_on_single_gpu(model, batch_list_t, batch_list_v, batch_sequence_output_list, batch_visual_output_list):
        """Score all text batches against all video batches on one device."""
        sim_matrix = []
        use_proj = getattr(args, 'proj', 0) and args.proj > 0 and hasattr(model, 'proj_head')
        use_binary = getattr(args, 'binary_eval', False) and use_proj

        for idx1, b1 in enumerate(batch_list_t):
            input_mask, segment_ids, *_tmp = b1
            sequence_output = batch_sequence_output_list[idx1]
            each_row = []

            if use_proj:
                # Text: [B,1,512] -> [B,512] -> proj -> act
                t = sequence_output.squeeze(1)
                t = model.proj_activation(model.proj_head(t))
                if use_binary:
                    # Binarise to +/-1 by sign.
                    t_vec = torch.where(t > 0, torch.tensor(1.0, device=t.device, dtype=t.dtype),
                                        torch.tensor(-1.0, device=t.device, dtype=t.dtype))
                else:
                    t_vec = F.normalize(t, dim=-1)

                for idx2, b2 in enumerate(batch_list_v):
                    video_mask, *_tmp = b2
                    visual_output = batch_visual_output_list[idx2]
                    v = model.proj_activation(model.proj_head(visual_output))  # [B, T, P]

                    # Robust mask handling: ensure vm is [B, T, 1]
                    vm = video_mask
                    # Squeeze stray singleton dims (e.g., [B, T, 1] -> [B, T])
                    while vm.dim() > 2 and vm.size(-1) == 1:
                        vm = vm.squeeze(-1)
                    # If still higher-rank, flatten to [B, -1] safely
                    if vm.dim() > 2:
                        vm = vm.view(vm.size(0), -1)
                    # If [T] vector sneaks in, expand batch
                    if vm.dim() == 1:
                        vm = vm.unsqueeze(0)
                    # Align time length to visual features
                    if vm.size(1) != v.size(1):
                        vm = vm[:, :v.size(1)]
                    vm = vm.to(dtype=v.dtype).unsqueeze(-1)  # [B, T, 1]

                    # Masked temporal sum -> [B, P]
                    v = (v * vm).sum(dim=1)
                    # Squeeze any trailing singleton dim that might remain
                    if v.dim() > 2 and v.size(-1) == 1:
                        v = v.squeeze(-1)

                    if use_binary:
                        v_vec = torch.where(v > 0, torch.tensor(1.0, device=v.device, dtype=v.dtype),
                                            torch.tensor(-1.0, device=v.device, dtype=v.dtype))
                    else:
                        v_vec = F.normalize(v, dim=-1)
                    scores = torch.matmul(t_vec, v_vec.t())
                    each_row.append(scores.cpu().detach().numpy())
            else:
                for idx2, b2 in enumerate(batch_list_v):
                    video_mask, *_tmp = b2
                    visual_output = batch_visual_output_list[idx2]
                    b1b2_logits, *_tmp = model.get_similarity_logits(
                        sequence_output, visual_output, input_mask, video_mask, loose_type=model.loose_type)
                    b1b2_logits = b1b2_logits.cpu().detach().numpy()
                    each_row.append(b1b2_logits)

            each_row = np.concatenate(tuple(each_row), axis=-1)
            sim_matrix.append(each_row)
        return sim_matrix

    if n_gpu > 1:
        # Shard the text batches over the devices; every device sees all videos.
        device_ids = list(range(n_gpu))
        batch_list_t_splits, batch_list_v_splits = [], []
        batch_t_output_splits, batch_v_output_splits = [], []
        bacth_len = len(batch_list_t)
        split_len = (bacth_len + n_gpu - 1) // n_gpu
        for dev_id in device_ids:
            s_, e_ = dev_id * split_len, (dev_id + 1) * split_len
            if dev_id == 0:
                batch_list_t_splits.append(batch_list_t[s_:e_])
                batch_list_v_splits.append(batch_list_v)
                batch_t_output_splits.append(batch_sequence_output_list[s_:e_])
                batch_v_output_splits.append(batch_visual_output_list)
            else:
                devc = torch.device(f'cuda:{dev_id}')
                devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_t[s_:e_]]
                batch_list_t_splits.append(devc_batch_list)
                devc_batch_list = [tuple(t.to(devc) for t in b) for b in batch_list_v]
                batch_list_v_splits.append(devc_batch_list)
                devc_batch_list = [b.to(devc) for b in batch_sequence_output_list[s_:e_]]
                batch_t_output_splits.append(devc_batch_list)
                devc_batch_list = [b.to(devc) for b in batch_visual_output_list]
                batch_v_output_splits.append(devc_batch_list)

        parameters_tuple_list = [(batch_list_t_splits[dev_id], batch_list_v_splits[dev_id],
                                  batch_t_output_splits[dev_id], batch_v_output_splits[dev_id]) for dev_id in device_ids]
        parallel_outputs = parallel_apply(_run_on_single_gpu, model, parameters_tuple_list, device_ids)
        sim_matrix = []
        for part in parallel_outputs:
            sim_matrix += part
        sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
    else:
        sim_matrix = _run_on_single_gpu(model, batch_list_t, batch_list_v,
                                        batch_sequence_output_list, batch_visual_output_list)
        sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)

    # (C) multi-sentence handling and metrics
    if multi_sentence_:
        logger.info("before reshape, sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
        sim_matrix_flat = sim_matrix.copy()  # keep the 2-D matrix for the per-query Top-K
        cut_off_points2len_ = [itm + 1 for itm in cut_off_points_]
        spans = list(zip([0] + cut_off_points2len_[:-1], cut_off_points2len_))
        max_length = max([e_ - s_ for s_, e_ in spans])
        sim_matrix_new = []
        for s_, e_ in spans:
            # Pad each video's sentence group with -inf rows up to max_length.
            sim_matrix_new.append(np.concatenate((sim_matrix[s_:e_],
                                                  np.full((max_length - e_ + s_, sim_matrix.shape[1]), -np.inf)), axis=0))
        sim_matrix = np.stack(tuple(sim_matrix_new), axis=0)
        logger.info("after reshape, sim matrix size: %d x %d x %d",
                    sim_matrix.shape[0], sim_matrix.shape[1], sim_matrix.shape[2])

        tv_metrics = tensor_text_to_video_metrics(sim_matrix)
        vt_metrics = compute_metrics(tensor_video_to_text_sim(sim_matrix))
    else:
        logger.info("sim matrix size: %d x %d", sim_matrix.shape[0], sim_matrix.shape[1])
        # Save a heatmap of the top-left 100x100 corner as a quick visual check.
        plt.figure(figsize=(8, 6))
        plt.imshow(sim_matrix[:100, :100], aspect='auto')
        plt.title('Similarity Matrix Heatmap')
        plt.xlabel('Video Index')
        plt.ylabel('Text Index')
        plt.tight_layout()
        out_path = os.path.join(args.output_dir, 'sim_matrix_heatmap.png')
        plt.savefig(out_path)
        plt.close()
        logger.info(f"Saved sim_matrix heatmap to {out_path}")

        sim_matrix_flat = sim_matrix  # already 2-D
        tv_metrics = compute_metrics(sim_matrix)
        vt_metrics = compute_metrics(sim_matrix.T)
        logger.info('\t Length-T: %d, Length-V:%d', len(sim_matrix), len(sim_matrix[0]))

    logger.info("Text-to-Video:")
    logger.info('\t>>> R@1: %.1f - R@5: %.1f - R@10: %.1f - Median R: %.1f - Mean R: %.1f',
                tv_metrics['R1'], tv_metrics['R5'], tv_metrics['R10'], tv_metrics['MR'], tv_metrics['MeanR'])
    logger.info("Video-to-Text:")
    logger.info('\t>>> V2T$R@1: %.1f - V2T$R@5: %.1f - V2T$R@10: %.1f - V2T$Median R: %.1f - V2T$Mean R: %.1f',
                vt_metrics['R1'], vt_metrics['R5'], vt_metrics['R10'], vt_metrics['MR'], vt_metrics['MeanR'])

    # (D) restore query texts and dump the Top-10 per query
    all_queries = []
    logger.info(f"text_input_ids_list 개수: {len(text_input_ids_list)}")

    for batch_idx, ids_batch in enumerate(text_input_ids_list):
        if ids_batch is None:
            logger.warning(f"배치 {batch_idx}: ids_batch가 None입니다.")
            continue
        try:
            ids_batch = ids_batch if isinstance(ids_batch, torch.Tensor) else torch.as_tensor(ids_batch)
            logger.info(f"배치 {batch_idx}: shape={ids_batch.shape}")
            for row_idx, row in enumerate(ids_batch):
                decoded = _decode_query(tokenizer, row)
                all_queries.append(decoded)
                if batch_idx == 0 and row_idx < 3:  # log only the first three samples
                    logger.info(f"샘플 디코딩 결과 [{batch_idx}-{row_idx}]: '{decoded}'")
        except Exception as e:
            # Keep going even if one batch cannot be decoded.
            logger.error(f"배치 {batch_idx} 처리 중 오류: {str(e)}")
            continue

    logger.info(f"총 {len(all_queries)}개의 쿼리가 디코딩되었습니다.")

    # Make sure video_ids lines up with the matrix columns.
    num_videos = sim_matrix_flat.shape[1]
    if len(video_ids) != num_videos:
        logger.warning("video_ids 길이(%d)와 비디오 수(%d)가 달라 index로 대체합니다.",
                       len(video_ids), num_videos)
        video_ids = [str(i) for i in range(num_videos)]

    # Never index past the matrix rows if the decoded query count disagrees.
    num_queries = min(len(all_queries), sim_matrix_flat.shape[0])
    if num_queries != len(all_queries):
        logger.warning("Decoded %d queries but the similarity matrix has %d rows; dumping the first %d.",
                       len(all_queries), sim_matrix_flat.shape[0], num_queries)

    topk = 10
    out_tsv = os.path.join(args.output_dir, f"t2v_top10{suffix}.tsv")
    out_json = os.path.join(args.output_dir, f"t2v_top10{suffix}.json")

    if args.local_rank == 0:
        results_dict = {}
        # One pass produces both the TSV rows and the structured JSON payload.
        with open(out_tsv, "w", encoding="utf-8") as f:
            f.write("query_idx\tquery\tvideo_rank\tvideo_id\tvideo_idx\tscore\n")
            for qi in range(num_queries):
                q = all_queries[qi]
                scores = sim_matrix_flat[qi]
                top_videos = []
                for rank, vidx in enumerate(_topk_indices(scores, topk), 1):
                    f.write(f"{qi}\t{q}\t{rank}\t{video_ids[vidx]}\t{int(vidx)}\t{float(scores[vidx]):.6f}\n")
                    top_videos.append({
                        "rank": rank,
                        "video_id": video_ids[vidx],
                        "video_idx": int(vidx),
                        "score": float(scores[vidx])
                    })
                results_dict[f"query_{qi+1}"] = {
                    "query_text": q,
                    "top_videos": top_videos
                }

        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, ensure_ascii=False, indent=2)

        logger.info("T2V Top-10 per query 저장 완료:")
        logger.info(" TSV 파일: %s", out_tsv)
        logger.info(" JSON 파일: %s", out_json)
        logger.info("총 %d개 쿼리에 대한 top-10 결과가 저장되었습니다.", len(all_queries))

    return tv_metrics['R1']
def _freeze_lower_clip_layers(clip, freeze_layer_num, linear_patch):
    """Turn off gradients for CLIP parameters below ``freeze_layer_num``.

    A parameter keeps its current ``requires_grad`` setting when any of the
    following holds; otherwise ``requires_grad`` is set to ``False``:

    * its name starts with one of the top-layer prefixes (final layer norms,
      projections, logit scale);
    * it belongs to a transformer residual block whose index is
      ``>= freeze_layer_num``;
    * ``linear_patch == "3d"`` and its name contains ``"conv2."``.

    Args:
        clip: Module exposing ``named_parameters()``.
        freeze_layer_num: Residual blocks with an index below this are frozen.
        linear_patch: Value of ``args.linear_patch``.
    """
    always_trained_prefixes = (
        "ln_final.",
        "text_projection",
        "logit_scale",
        "visual.ln_post.",
        "visual.proj",
    )
    resblock_prefixes = (
        "visual.transformer.resblocks.",
        "transformer.resblocks.",
    )

    for name, param in clip.named_parameters():
        # Top layers always need to train.
        if name.startswith(always_trained_prefixes):
            continue

        if name.startswith(resblock_prefixes):
            layer_num = int(name.split(".resblocks.")[1].split(".")[0])
            if layer_num >= freeze_layer_num:
                continue  # At or above the cut-off: keep training.

        # The previous check was `name.find("conv2.")` used as a boolean.
        # str.find returns -1 (truthy) when the substring is absent, so in
        # "3d" mode every parameter reaching this point was skipped and
        # nothing was ever frozen. A membership test expresses the intent.
        if linear_patch == "3d" and "conv2." in name:
            continue

        # Parameters below freeze_layer_num are frozen.
        param.requires_grad = False


def main():
    """Script entry point: set up distributed state, then train or evaluate.

    Steps, in order:

    1. Parse the command line; a parseable ``LOCAL_RANK`` environment variable
       overrides ``args.local_rank``.
    2. Select the CUDA device for this rank and initialise the NCCL process
       group from ``env://`` with a 30 minute timeout.
    3. Initialise seed/logger, device, tokenizer and model.
    4. Optionally freeze the lower CLIP layers (``args.freeze_layer_num``).
    5. Build the test/val dataloaders, each falling back to the other when
       the dataset does not provide it.
    6. With ``--do_train``: train for ``args.epochs`` epochs (optionally
       resuming optimizer state and epoch from ``args.resume_model``); rank 0
       saves a checkpoint and evaluates after every epoch.
       With ``--do_eval`` (and no ``--do_train``): rank 0 evaluates once.
    """
    global logger
    args = get_args()

    # When the launcher exports LOCAL_RANK it takes precedence over the CLI.
    env_local_rank = os.environ.get("LOCAL_RANK")
    if env_local_rank is not None:
        try:
            args.local_rank = int(env_local_rank)
        except ValueError:
            # Best effort: a non-integer value keeps the CLI-provided rank.
            pass

    torch.cuda.set_device(args.local_rank)

    from datetime import timedelta
    import torch.distributed as dist

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=timedelta(minutes=30),
    )

    args = set_seed_logger(args)
    device, n_gpu = init_device(args, args.local_rank)

    tokenizer = ClipTokenizer()

    assert args.task_type == "retrieval"
    model = init_model(args, device, n_gpu, args.local_rank)

    ## ####################################
    # freeze testing
    ## ####################################
    assert -1 <= args.freeze_layer_num <= 12
    if hasattr(model, "clip") and args.freeze_layer_num > -1:
        _freeze_lower_clip_layers(model.clip, args.freeze_layer_num, args.linear_patch)

    ## ####################################
    # dataloader loading
    ## ####################################
    assert args.datatype in DATALOADER_DICT
    loaders = DATALOADER_DICT[args.datatype]
    assert loaders["test"] is not None or loaders["val"] is not None

    test_dataloader, test_length = None, 0
    if loaders["test"] is not None:
        test_dataloader, test_length = loaders["test"](args, tokenizer)

    if loaders["val"] is not None:
        val_dataloader, val_length = loaders["val"](args, tokenizer, subset="val")
    else:
        val_dataloader, val_length = test_dataloader, test_length

    # Report validation results if the dataset has no "test" loader.
    if test_dataloader is None:
        test_dataloader, test_length = val_dataloader, val_length

    if args.local_rank == 0:
        logger.info("***** Running test *****")
        logger.info(" Num examples = %d", test_length)
        logger.info(" Batch size = %d", args.batch_size_val)
        logger.info(" Num steps = %d", len(test_dataloader))
        logger.info("***** Running val *****")
        logger.info(" Num examples = %d", val_length)

    ## ####################################
    # train and eval
    ## ####################################
    if args.do_train:
        train_dataloader, train_length, train_sampler = loaders["train"](args, tokenizer)

        # Ceil-divide batches by the accumulation factor. The earlier
        # expression wrapped only the numerator in int() and used true
        # division, producing a float that was not the ceiling.
        accumulation_steps = args.gradient_accumulation_steps
        steps_per_epoch = (len(train_dataloader) + accumulation_steps - 1) // accumulation_steps
        num_train_optimization_steps = steps_per_epoch * args.epochs

        optimizer, scheduler, model = prep_optimizer(
            args, model, num_train_optimization_steps, device, n_gpu, args.local_rank,
            coef_lr=args.coef_lr)

        if args.local_rank == 0:
            logger.info("***** Running training *****")
            logger.info(" Num examples = %d", train_length)
            logger.info(" Batch size = %d", args.batch_size)
            logger.info(" Num steps = %d", num_train_optimization_steps * args.gradient_accumulation_steps)

        best_score = 0.00001
        best_output_model_file = "None"

        ## ##############################################################
        # resume optimizer state and epoch counter to continue training
        ## ##############################################################
        resumed_epoch = 0
        if args.resume_model:
            # NOTE(review): torch.load unpickles the file; only resume from
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(args.resume_model, map_location='cpu')
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            resumed_epoch = checkpoint['epoch'] + 1

        global_step = 0
        for epoch in range(resumed_epoch, args.epochs):
            train_sampler.set_epoch(epoch)
            tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
                                               scheduler, global_step, local_rank=args.local_rank)
            if args.local_rank == 0:
                logger.info("Epoch %d/%s Finished, Train Loss: %f", epoch + 1, args.epochs, tr_loss)

                output_model_file = save_model(epoch, args, model, optimizer, tr_loss, type_name="")

                # Evaluation runs on the test loader after every epoch;
                # evaluating on val_dataloader instead is possible but
                # time-consuming.
                R1 = eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
                if best_score <= R1:
                    best_score = R1
                    best_output_model_file = output_model_file
                logger.info("The best model is: {}, the R1 is: {:.4f}".format(best_output_model_file, best_score))

        ## Uncomment if want to test on the best checkpoint
        # if args.local_rank == 0:
        #     model = load_model(-1, args, n_gpu, device, model_file=best_output_model_file)
        #     eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)

    elif args.do_eval:
        if args.local_rank == 0:
            eval_epoch(args, model, test_dataloader, device, n_gpu, tokenizer)
# Run the retrieval pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()