| | import torch |
| | import os |
| | import pandas as pd |
| | from tqdm import tqdm |
| | import sed_scores_eval |
| | from desed_task.evaluation.evaluation_measures import (compute_per_intersection_macro_f1, |
| | compute_psds_from_operating_points, |
| | compute_psds_from_scores) |
| | from local.utils import (batched_decode_preds,) |
| | from utils.sed import Encoder |
| | import numpy as np |
| |
|
| |
|
def _psds_summary_row(overall, per_class, name):
    """Build one CSV row from a PSDS result.

    Args:
        overall: The overall PSDS value returned by the scorer.
        per_class: Mapping of event class -> per-class PSDS value.
        name: Label stored in the row's ``'name'`` column.

    Returns:
        A new dict holding the per-class values followed by the
        ``'overall'``, ``'macro_averaged'`` and ``'name'`` entries.
    """
    row = dict(per_class)
    # Average over the event classes only: this must be computed before
    # 'overall' is inserted, otherwise the overall score is counted as if it
    # were one more class.
    macro_averaged = np.array(list(per_class.values())).mean()
    row['overall'] = overall
    row['macro_averaged'] = macro_averaged
    row['name'] = name
    return row


@torch.no_grad()
def val_psds(model, val_loader, params, epoch, split, save_path, device):
    """Score ``model`` on ``val_loader`` with PSDS and save per-class results.

    Two sets of scores are evaluated: the raw sigmoid predictions, and a
    "tsed" variant in which, for every file, the predictions of all event
    classes that do not appear in that file's ground-truth annotations are
    zeroed before decoding.

    Args:
        model: Network exposing ``forward_to_spec(audio)`` and being callable
            as ``model(mel, cls)``. It is switched to eval mode here and is
            not switched back.
        val_loader: Iterable yielding ``(audio, filenames)`` batches.
        params: Config dict. Reads ``params['data'][split]`` keys ``'label'``,
            ``'csv'`` and ``'dur'``, and
            ``params['data']['train_data']['clap_dir']``.
        epoch: Used as the file name of the saved CSV.
        split: Key selecting the data split inside ``params['data']``.
        save_path: Directory under which ``psds_cls/{epoch}.csv`` is written.
        device: Device the embeddings and audio are moved to.

    Returns:
        Tuple ``(psds, psds_tsed)`` with the overall PSDS of the raw and of
        the weak-label-masked predictions.
    """
    label_df = pd.read_csv(params['data'][split]['label'])
    EVENTS = label_df['label'].tolist()

    # One stored embedding per event class, concatenated along dim 1.
    # NOTE(review): the expand() below assumes each stored tensor has a
    # leading dimension of 1 -- confirm against the files in clap_dir.
    clap_dir = params['data']['train_data']['clap_dir']
    clap_emb = [
        torch.load(clap_dir + event + '.pt').to(device).unsqueeze(1)
        for event in EVENTS
    ]
    cls = torch.cat(clap_emb, dim=1)
    N = cls.shape[1]

    encoder = Encoder(EVENTS, audio_len=10, frame_len=160, frame_hop=160, net_pooling=4, sr=16000)

    model.eval()
    test_csv = params['data'][split]["csv"]
    test_dur = params['data'][split]["dur"]

    gt = pd.read_csv(test_csv, sep='\t')
    # Build the filename -> annotated-labels map once instead of filtering the
    # whole ground-truth frame for every file of every batch.
    weak_labels = {
        filename: set(labels)
        for filename, labels in gt.groupby('filename')['event_label']
    }

    test_scores_postprocessed_buffer = {}
    test_scores_postprocessed_buffer_tsed = {}
    test_thresholds = [0.5]

    for audio, filenames in tqdm(val_loader):
        B = audio.shape[0]

        audio = audio.to(device)
        mel = model.forward_to_spec(audio)

        # Expand into a per-batch name. Rebinding ``cls`` itself would leave
        # it with a leading dimension of B, and expand() would then raise on
        # the next batch with a different size (e.g. a smaller final batch).
        cls_batch = cls.expand(B, -1, -1)

        preds = torch.sigmoid(model(mel, cls_batch)).reshape(B, N, -1)
        preds_tsed = preds.clone()

        for idx, filename in enumerate(filenames):
            # Files missing from the ground truth have no annotated event, so
            # every class is masked for them.
            present = weak_labels.get(filename, ())
            absent = [j for j, event in enumerate(EVENTS) if event not in present]
            if absent:
                preds_tsed[idx, absent] = 0.0

        decode_jobs = (
            (test_scores_postprocessed_buffer, preds),
            (test_scores_postprocessed_buffer_tsed, preds_tsed),
        )
        for scores_buffer, batch_preds in decode_jobs:
            _, scores_postprocessed_strong, _ = batched_decode_preds(
                batch_preds,
                filenames,
                encoder,
                median_filter=9,
                thresholds=list(test_thresholds),
            )
            scores_buffer.update(scores_postprocessed_strong)

    ground_truth = sed_scores_eval.io.read_ground_truth_events(test_csv)
    audio_durations = sed_scores_eval.io.read_audio_durations(test_dur)

    # Restrict the references to the files that were actually scored.
    ground_truth = {
        audio_id: ground_truth[audio_id]
        for audio_id in test_scores_postprocessed_buffer
    }
    audio_durations = {
        audio_id: audio_durations[audio_id]
        for audio_id in test_scores_postprocessed_buffer
    }

    # Both variants are scored with identical PSDS settings.
    psds1_kwargs = dict(
        dtc_threshold=0.7,
        gtc_threshold=0.7,
        cttc_threshold=None,
        alpha_ct=0.0,
        alpha_st=0.0,
    )

    psds1_sed_scores_eval, psds1_cls = compute_psds_from_scores(
        test_scores_postprocessed_buffer,
        ground_truth,
        audio_durations,
        **psds1_kwargs,
    )
    psds1_sed_scores_eval_tsed, psds1_cls_tsed = compute_psds_from_scores(
        test_scores_postprocessed_buffer_tsed,
        ground_truth,
        audio_durations,
        **psds1_kwargs,
    )

    psds_cls = pd.DataFrame([
        _psds_summary_row(psds1_sed_scores_eval, psds1_cls, 'psds1'),
        _psds_summary_row(psds1_sed_scores_eval_tsed, psds1_cls_tsed, 'psds1_tsed'),
    ])

    os.makedirs(f'{save_path}/psds_cls/', exist_ok=True)
    psds_cls.to_csv(f'{save_path}/psds_cls/{epoch}.csv', index=False)

    return psds1_sed_scores_eval, psds1_sed_scores_eval_tsed