import json
import logging
import math
import os
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from open_clip import ClipLoss, gather_features
from .distributed import is_master
from .zero_shot import zero_shot_eval


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
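

# Illustrative use of AverageMeter (hypothetical numbers, not executed here):
#     meter = AverageMeter()
#     meter.update(0.50, n=32)  # batch-mean value 0.50 over 32 samples
#     meter.update(0.40, n=32)
#     meter.avg  # -> 0.45, the sample-weighted running average

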
def unwrap_model(model):
    if hasattr(model, "module"):
        return model.module
    else:
        return model


def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
):
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
        mlp_loss=args.clap_mlploss,
        weight_loss_kappa=args.kappa,
    )

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        if isinstance(scheduler, dict):
            for s in scheduler.values():
                s(step)
        else:
            scheduler(step)
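        # Note: the scheduler is stepped once per batch (per optimizer
        # update), not per epoch; when it is a dict of schedulers, all of
        # them advance in lockstep with the shared global step.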
        audios = batch  # the full batch dict (waveform etc.) is the audio input
        texts = batch["text"]

        data_time_m.update(time.time() - end)
        if isinstance(optimizer, dict):
            for o_ in optimizer.values():
                o_.zero_grad()
        else:
            optimizer.zero_grad()

        with autocast():
            (
                audio_features,
                text_features,
                audio_features_mlp,
                text_features_mlp,
                logit_scale_a,
                logit_scale_t,
            ) = model(audios, texts, device)

            if args.clap_mlploss:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                    logit_scale_t=logit_scale_t,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                )
            else:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                )
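
        # AMP bookkeeping: gradients exist in scaled form after backward();
        # under Horovod each optimizer must finish its allreduce
        # (synchronize()) before unscale_, and skip_synchronize() avoids a
        # second allreduce inside scaler.step().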
        if isinstance(optimizer, dict):
            if scaler is not None:
                scaler.scale(total_loss).backward()
                for o_ in optimizer.values():
                    if args.horovod:
                        o_.synchronize()
                        scaler.unscale_(o_)
                        with o_.skip_synchronize():
                            scaler.step(o_)
                    else:
                        scaler.step(o_)
                scaler.update()
            else:
                total_loss.backward()
                for o_ in optimizer.values():
                    o_.step()
        else:
            if scaler is not None:
                scaler.scale(total_loss).backward()
                if args.horovod:
                    optimizer.synchronize()
                    scaler.unscale_(optimizer)
                    with optimizer.skip_synchronize():
                        scaler.step(optimizer)
                else:
                    scaler.step(optimizer)
                scaler.update()
            else:
                total_loss.backward()
                optimizer.step()

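        # Clamp the learned temperature(s) so exp(logit_scale) stays within
        # [1, 100], following the original CLIP recipe.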
        with torch.no_grad():
            unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
            if args.clap_mlploss:
                unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audios, dict):
                batch_size = len(audios["waveform"])
            else:
                batch_size = len(audios)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE: loss is coarsely sampled -- only on the master node and
            # only at log steps.
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar_a = logit_scale_a.item()
            logit_scale_scalar_t = logit_scale_t.item()
            if isinstance(optimizer, dict):
                if args.clap_mlploss:
                    logging.info(
                        f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                        f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                        f"Data (t): {data_time_m.avg:.3f} "
                        f"Batch (t): {batch_time_m.avg:.3f} "
                        f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f} "
                        f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
                    )
                    log_data = {
                        "loss": loss_m.val,
                        "data_time": data_time_m.val,
                        "batch_time": batch_time_m.val,
                        "scale_audio": logit_scale_scalar_a,
                        "scale_text": logit_scale_scalar_t,
                        "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
                    }
                else:
                    logging.info(
                        f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                        f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                        f"Data (t): {data_time_m.avg:.3f} "
                        f"Batch (t): {batch_time_m.avg:.3f} "
                        f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
                    )
                    log_data = {
                        "loss": loss_m.val,
                        "data_time": data_time_m.val,
                        "batch_time": batch_time_m.val,
                        "scale_audio": logit_scale_scalar_a,
                        "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
                    }
            else:
                if args.clap_mlploss:
                    logging.info(
                        f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                        f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                        f"Data (t): {data_time_m.avg:.3f} "
                        f"Batch (t): {batch_time_m.avg:.3f} "
                        f"LR: {optimizer.param_groups[0]['lr']:.5f} "
                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f} "
                        f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
                    )
                    log_data = {
                        "loss": loss_m.val,
                        "data_time": data_time_m.val,
                        "batch_time": batch_time_m.val,
                        "scale_audio": logit_scale_scalar_a,
                        "scale_text": logit_scale_scalar_t,
                        "lr": optimizer.param_groups[0]["lr"],
                    }
                else:
                    logging.info(
                        f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                        f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                        f"Data (t): {data_time_m.avg:.3f} "
                        f"Batch (t): {batch_time_m.avg:.3f} "
                        f"LR: {optimizer.param_groups[0]['lr']:.5f} "
                        f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
                    )
                    log_data = {
                        "loss": loss_m.val,
                        "data_time": data_time_m.val,
                        "batch_time": batch_time_m.val,
                        "scale_audio": logit_scale_scalar_a,
                        "lr": optimizer.param_groups[0]["lr"],
                    }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # Reset the timing meters so the next log line reflects a fresh window.
            batch_time_m.reset()
            data_time_m.reset()


def evaluate(model, data, epoch, args, tb_writer=None):
    metrics = {}
    if not args.parallel_eval:
        if not is_master(args):
            return metrics
    device = torch.device(args.device)
    model.eval()

    # (zero-shot evaluation via zero_shot_eval is not run here)
    if is_master(args):
        print("Evaluating...")
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    if args.val_dataset_names == ["Clotho", "audiocaps"]:
        # Retrieval eval on exactly Clotho + AudioCaps follows the
        # 5-captions-per-audio protocol in evaluate_clotho_audiocaps below.
        if args.parallel_eval:
            raise NotImplementedError(
                "Parallel evaluation not supported for eval only Clotho and audiocaps."
            )
        val_metrics_per_dataset = evaluate_clotho_audiocaps(
            model, data, epoch, args, autocast, device, tb_writer
        )
        for m in val_metrics_per_dataset.values():
            metrics.update(m)
        if "epoch" not in metrics.keys():
            metrics.update({"epoch": epoch})
        metrics = select_top_metric_clotho_audiocaps(
            metrics, val_metrics_per_dataset, args
        )
    elif "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        eval_info = {}
        if args.clap_mlploss:
            eval_info["all"] = {
                "cumulative_loss": 0.0,
                "num_samples": 0,
                "all_audio_features": [],
                "all_text_features": [],
                "all_audio_features_mlp": [],
                "all_text_features_mlp": [],
            }
        else:
            eval_info["all"] = {
                "cumulative_loss": 0.0,
                "num_samples": 0,
                "all_audio_features": [],
                "all_text_features": [],
            }
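
        # eval_info accumulates features per evaluation subset. Keys come from
        # the webdataset shard path ("-".join of the two directory components
        # above the file name in __url__), plus an "all" bucket spanning every
        # subset.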
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audios = batch
                texts = batch["text"]

                all_names = list(
                    set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
                )
                for name in all_names:
                    if name not in eval_info.keys():
                        if args.clap_mlploss:
                            eval_info[name] = {
                                "cumulative_loss": 0.0,
                                "num_samples": 0,
                                "all_audio_features": [],
                                "all_text_features": [],
                                "all_audio_features_mlp": [],
                                "all_text_features_mlp": [],
                            }
                        else:
                            eval_info[name] = {
                                "cumulative_loss": 0.0,
                                "num_samples": 0,
                                "all_audio_features": [],
                                "all_text_features": [],
                            }
                with autocast():
                    (
                        audio_features,
                        text_features,
                        audio_features_mlp,
                        text_features_mlp,
                        logit_scale_a,
                        logit_scale_t,
                    ) = model(audios, texts, device)

                    if args.parallel_eval:
                        # Multi-GPU eval: gather features from all ranks.
                        if args.clap_mlploss:
                            (
                                audio_features,
                                text_features,
                                audio_features_mlp,
                                text_features_mlp,
                            ) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                audio_features_mlp=audio_features_mlp,
                                text_features_mlp=text_features_mlp,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )
                        else:
                            audio_features, text_features = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )

                    if is_master(args):
                        num_samples += audio_features.shape[0]
                        for n in [*all_names, "all"]:
                            if n == "all":
                                eval_info[n]["all_audio_features"].append(
                                    audio_features.cpu()
                                )
                                eval_info[n]["all_text_features"].append(
                                    text_features.cpu()
                                )
                                if args.clap_mlploss:
                                    eval_info[n]["all_audio_features_mlp"].append(
                                        audio_features_mlp.cpu()
                                    )
                                    eval_info[n]["all_text_features_mlp"].append(
                                        text_features_mlp.cpu()
                                    )
                            else:
                                idx = np.where(
                                    np.array(
                                        [
                                            "-".join(b.split("/")[-3:-1])
                                            for b in batch["__url__"]
                                        ]
                                    )
                                    == n
                                )[0]
                                eval_info[n]["all_audio_features"].append(
                                    audio_features.cpu().index_select(
                                        0, torch.tensor(idx).long()
                                    )
                                )
                                eval_info[n]["all_text_features"].append(
                                    text_features.cpu().index_select(
                                        0, torch.tensor(idx).long()
                                    )
                                )
                                if args.clap_mlploss:
                                    eval_info[n]["all_audio_features_mlp"].append(
                                        audio_features_mlp.cpu().index_select(
                                            0, torch.tensor(idx).long()
                                        )
                                    )
                                    eval_info[n]["all_text_features_mlp"].append(
                                        text_features_mlp.cpu().index_select(
                                            0, torch.tensor(idx).long()
                                        )
                                    )

                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )

            if is_master(args):
                val_metrics_per_dataset = {}
                for n in eval_info.keys():
                    if args.clap_mlploss:
                        metrics_single_dataset = get_metrics(
                            audio_features=torch.cat(eval_info[n]["all_audio_features"]),
                            text_features=torch.cat(eval_info[n]["all_text_features"]),
                            logit_scale_a=logit_scale_a.cpu(),
                            audio_features_mlp=torch.cat(
                                eval_info[n]["all_audio_features_mlp"]
                            ),
                            text_features_mlp=torch.cat(
                                eval_info[n]["all_text_features_mlp"]
                            ),
                            logit_scale_t=logit_scale_t.cpu(),
                            mlp_loss=args.clap_mlploss,
                        )
                    else:
                        metrics_single_dataset = get_metrics(
                            audio_features=torch.cat(eval_info[n]["all_audio_features"]),
                            text_features=torch.cat(eval_info[n]["all_text_features"]),
                            logit_scale_a=logit_scale_a.cpu(),
                            mlp_loss=args.clap_mlploss,
                        )
                    val_metrics_per_dataset[n] = {
                        n + "/" + k: v for k, v in metrics_single_dataset.items()
                    }
                    metrics.update(val_metrics_per_dataset[n])
                if "epoch" not in metrics.keys():
                    metrics.update({"epoch": epoch})

    if is_master(args):
        if not metrics:
            return metrics

        logging.info(
            f"Eval Epoch: {epoch} "
            + "\n".join(
                [
                    "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
                    for m in val_metrics_per_dataset.values()
                ]
            )
        )

        if args.save_logs:
            for name, val in metrics.items():
                if tb_writer is not None:
                    tb_writer.add_scalar(f"val/{name}", val, epoch)

            with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
                f.write(json.dumps(metrics))
                f.write("\n")

        if args.wandb:
            assert wandb is not None, "Please install wandb."
            for name, val in metrics.items():
                wandb.log({f"val/{name}": val, "epoch": epoch})

        return metrics
    else:
        return metrics


def get_metrics(
    audio_features,
    text_features,
    logit_scale_a,
    audio_features_mlp=None,
    text_features_mlp=None,
    logit_scale_t=None,
    mlp_loss=False,
):
    metrics = {}
    if mlp_loss:
        # Cross-modal logits from both feature / MLP-feature pairings.
        a_logits_per_audio = (
            (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
        )
        a_logits_per_text = a_logits_per_audio.t().detach().cpu()
        t_logits_per_audio = (
            (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
        )
        t_logits_per_text = t_logits_per_audio.t().detach().cpu()

        labels = torch.arange(audio_features.shape[0]).long()
        # Four-term cross-entropy (2 directions x 2 pairings), averaged.
        total_loss = (
            F.cross_entropy(a_logits_per_audio, labels)
            + F.cross_entropy(a_logits_per_text, labels)
            + F.cross_entropy(t_logits_per_audio, labels)
            + F.cross_entropy(t_logits_per_text, labels)
        ) / 4

        metrics["cumulative_loss"] = total_loss.item()
        metrics["num_samples"] = audio_features.shape[0]

        logits = {
            "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
            "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
        }
        ground_truth = torch.arange(len(text_features)).view(-1, 1)

    else:
        logits_per_audio = (
            (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
        )
        logits_per_text = logits_per_audio.t().detach().cpu()

        labels = torch.arange(audio_features.shape[0]).long()
        total_loss = (
            F.cross_entropy(logits_per_audio, labels)
            + F.cross_entropy(logits_per_text, labels)
        ) / 2

        metrics["cumulative_loss"] = total_loss.item()
        metrics["num_samples"] = audio_features.shape[0]

        logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}

        ground_truth = torch.arange(len(text_features)).view(-1, 1)

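    # For each query, argsort orders candidates by similarity, and `preds`
    # holds the 0-indexed rank of the ground-truth match. For example,
    # preds == 0 is a perfect retrieval: it counts toward R@1/R@5/R@10 and
    # contributes 1 / (0 + 1) = 1.0 to mAP@10.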
    for name, logit in logits.items():
        ranking = torch.argsort(logit, descending=True)
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
        metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))

    return metrics


def evaluate_clotho_audiocaps(
    model, data, epoch, args, autocast, device, tb_writer=None
):
    """
    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. For text-to-audio retrieval, run the ranking 5 times (once per caption) and average the results.
    2. For R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among the 5 captions.
    3. For mAP@10 in audio-to-text retrieval:
       3.1: sort the ranks of the 5 captions;
       3.2: exclude ranks >= 10 (0-indexed);
       3.3: compute the mAP over the remaining ranks: np.mean(np.arange(1, len(ranks) + 1) / ranks).
       That is, keep the captions ranked below 10 and assign ascending positions as ground
       truth: the best-ranked caption gets label 1, the second best gets label 2, etc.
    """
    # This path assumes single-process evaluation (parallel eval raises
    # NotImplementedError in evaluate() above).
    dataloader = data["val"].dataloader
    with torch.no_grad():
        eval_info = {}
        for i, batch in enumerate(dataloader):
            audios = batch

            # batch["full_text"] holds the 5 reference captions per audio.
            if args.tmodel == "transformer":
                from open_clip import tokenize

                texts = [tokenize(t) for t in batch["full_text"]]
                texts = torch.cat(texts)
            else:
                from .data import tokenizer

                texts = [tokenizer(t) for t in batch["full_text"]]
                texts = {
                    k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
                }

            all_names = list(
                set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
            )
            for name in all_names:
                if name not in eval_info.keys():
                    eval_info[name] = {
                        "cumulative_loss": 0.0,
                        "num_samples": 0,
                        "all_audio_features": [],
                        "all_text_features": [],
                    }
            with autocast():
                audio_features = model(audios, None, device)
                text_features = model(None, texts, device)
                audio_features = F.normalize(audio_features, dim=-1)
                text_features = F.normalize(text_features, dim=-1)

                all_names = list(
                    set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
                )
                for n in all_names:
                    idx = np.where(
                        np.array(
                            ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
                        )
                        == n
                    )[0]
                    eval_info[n]["all_audio_features"].append(
                        audio_features.cpu().index_select(0, torch.tensor(idx).long())
                    )
                    # Text features arrive flattened with 5 consecutive rows per
                    # audio; reshape to [-1, 5, dim] so index_select grabs all 5
                    # captions of each selected audio, then flatten back.
                    eval_info[n]["all_text_features"].append(
                        text_features.cpu()
                        .reshape([-1, 5, text_features.shape[1]])
                        .index_select(0, torch.tensor(idx).long())
                        .reshape([-1, text_features.shape[1]])
                    )

        val_metrics_all = {}

        for n in eval_info.keys():
            logit_scale_a, logit_scale_t = model(None, None, device)
            logit_scale_a = logit_scale_a.cpu()

            audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
            text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)

            logits_per_audio = (
                (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
            )
            logits_per_text = logits_per_audio.t().detach().cpu()

            logging.info(
                f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
                f"logits_per_text shape: {logits_per_text.shape}"
            )

            metrics = {}
            num_samples = audio_features.shape[0]
            metrics["num_samples"] = num_samples

            labels = torch.arange(audio_features.shape[0]).long()
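            # logits_per_audio is [num_samples, num_samples * 5] (5 captions
            # per audio). reshape(num_samples, num_samples, 5)[:, :, d] scores
            # every audio against the d-th caption of every audio: a square
            # matrix whose diagonal is the matching pair. Likewise,
            # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
            # pits the d-th caption of each audio against every audio.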
            audio_to_text_loss = [
                F.cross_entropy(
                    logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d],
                    labels,
                )
                for d in range(5)
            ]
            text_to_audio_loss = [
                F.cross_entropy(
                    logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :],
                    labels,
                )
                for d in range(5)
            ]
            total_loss = (
                np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)
            ) / 2

            metrics["cumulative_loss"] = total_loss.item()

            # Text-to-audio retrieval: rank once per caption slot, pool all 5 runs.
            pred_text = []
            for d in range(5):
                logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
                ground_truth = torch.arange(len(logit)).view(-1, 1)
                ranking = torch.argsort(logit, descending=True)
                preds = torch.where(ranking == ground_truth)[1]
                pred_text.append(preds.detach().cpu().numpy())
            pred_text_concat = np.concatenate(pred_text, axis=0)
            metrics["text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
            metrics["text_to_audio_median_rank"] = (
                np.floor(np.median(pred_text_concat)) + 1
            )
            for k in [1, 5, 10]:
                metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
            metrics["text_to_audio_mAP@10"] = np.mean(
                np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
            )

            # Audio-to-text retrieval: for each audio, rank all 5 * num_samples
            # captions and score against its 5 references.
            map_all = []
            pred_audio_all = []
            for d in range(num_samples):
                logit_single = logits_per_audio[d, :]
                ranking = torch.argsort(logit_single, descending=True)
                # The 5 matching captions of audio d occupy columns 5d .. 5d + 4.
                ground_truth = torch.arange(d * 5, d * 5 + 5)[None]
                all_pred = torch.where(
                    torch.stack([ranking] * 5) == ground_truth.view(-1, 1)
                )[1]
                min_pred = torch.min(all_pred)
                pred_audio_all.append(min_pred.detach().cpu().numpy())
                # Rule 3.1 of the docstring wants the kept ranks in ascending
                # order before computing precision-at-rank, hence the sort.
                all_pred_filter = np.sort(
                    all_pred[all_pred < 10].detach().cpu().numpy()
                )
                map_single = (
                    np.sum(
                        np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)
                    )
                    / 5
                )
                map_all.append(map_single)
            metrics["audio_to_text_mAP@10"] = np.mean(map_all)
            for k in [1, 5, 10]:
                metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)

            val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
    return val_metrics_all


def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
    """
    Calculate the model-selection score for Clotho + AudioCaps: the mean of
    audio-to-text and text-to-audio mAP@10, averaged over the datasets.
    """
    selection_performance_all = []
    for n in val_metrics_per_dataset.keys():
        selection_performance = (
            val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"]
            + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"]
        ) / 2
        selection_performance_all.append(selection_performance)
    return np.mean(selection_performance_all)


def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
    # Track the best selection performance seen so far (stored on args) and
    # mirror the corresponding metrics under "<dataset>-top/<metric>" keys.
    if not hasattr(args, "top_selection_performance"):
        selection_performance = calculate_selection_performance_clotho_audiocaps(
            val_metrics_per_dataset
        )
        # First evaluation: initialize the stored top metrics.
        metric_update = {}
        for n in val_metrics_per_dataset.keys():
            for k in val_metrics_per_dataset[n].keys():
                metric_update[
                    k.split("/")[0] + "-top" + "/" + k.split("/")[1]
                ] = val_metrics_per_dataset[n][k]
        metric_update["top_selection_performance"] = selection_performance
        metric_update["top-selection-epoch"] = metrics["epoch"]
        metrics.update(metric_update)
        args.top_metric = metric_update
        args.top_selection_performance = selection_performance
    else:
        selection_performance_new = calculate_selection_performance_clotho_audiocaps(
            val_metrics_per_dataset
        )
        selection_performance_old = args.top_selection_performance
        if selection_performance_new > selection_performance_old:
            metric_update = {}
            for n in val_metrics_per_dataset.keys():
                for k in val_metrics_per_dataset[n].keys():
                    metric_update[
                        k.split("/")[0] + "-top" + "/" + k.split("/")[1]
                    ] = val_metrics_per_dataset[n][k]
            metric_update["top_selection_performance"] = selection_performance_new
            metric_update["top-selection-epoch"] = metrics["epoch"]
            metrics.update(metric_update)
            args.top_metric = metric_update
            args.top_selection_performance = selection_performance_new
        else:
            metrics.update(args.top_metric)
    return metrics