| import os
|
| import pretty_midi
|
| import mir_eval
|
| import numpy as np
|
| import pandas as pd
|
|
|
|
|
# Input folders. These are relative paths, so they resolve against the
# current working directory of the process running the script.
DIR_GT: str = "dataset_evaluación/midis_gt"  # ground-truth MIDI files
DIR_PRED_MY: str = "dataset_evaluación/resultados_midis/cornetai"  # CornetAI predictions
DIR_PRED_OFF: str = "dataset_evaluación/resultados_midis/official_bp"  # Basic Pitch predictions

# Evaluation parameters.
ONSET_TOL: float = 0.15  # mir_eval onset_tolerance, in seconds (150 ms)
OFFSET_RATIO_VAL: float = 0.5  # mir_eval offset_ratio: fraction of the reference note duration
FS: int = 100  # piano-roll frames per second used for frame accuracy (10 ms frames)
|
|
|
| def get_frame_accuracy(pm_ref, pm_est):
|
| """Calcula la precisión por frame (Acc) comparando piano rolls."""
|
|
|
| pr_ref = (pm_ref.get_piano_roll(fs=FS) > 0).astype(int)
|
| pr_est = (pm_est.get_piano_roll(fs=FS) > 0).astype(int)
|
|
|
|
|
| max_len = max(pr_ref.shape[1], pr_est.shape[1])
|
| pr_ref = np.pad(pr_ref, ((0,0), (0, max_len - pr_ref.shape[1])))
|
| pr_est = np.pad(pr_est, ((0,0), (0, max_len - pr_est.shape[1])))
|
|
|
|
|
| tp = np.sum((pr_ref == 1) & (pr_est == 1))
|
| fp = np.sum((pr_ref == 0) & (pr_est == 1))
|
| fn = np.sum((pr_ref == 1) & (pr_est == 0))
|
|
|
| return (tp / (tp + fp + fn)) * 100 if (tp + fp + fn) > 0 else 0
|
|
|
def _note_arrays(pm):
    """Return ``(intervals, pitches_hz)`` for every pitched note in *pm*.

    Notes from all non-drum instruments are pooled, so the note metrics
    cover the same tracks that ``get_piano_roll`` renders for the frame
    metric. Notes whose end is not strictly after their start are skipped,
    because mir_eval rejects non-positive durations. ``intervals`` always
    has shape ``(n, 2)``, even when ``n == 0``, so empty transcriptions
    still pass mir_eval validation.
    """
    notes = [
        note
        for inst in pm.instruments
        if not inst.is_drum
        for note in inst.notes
        if note.end > note.start
    ]
    intervals = np.array([[n.start, n.end] for n in notes], dtype=float).reshape(-1, 2)
    pitches = np.array([pretty_midi.note_number_to_hz(n.pitch) for n in notes], dtype=float)
    return intervals, pitches


def get_full_metrics(path_ref, path_est):
    """Score one estimated MIDI file against its reference MIDI file.

    Args:
        path_ref: Path to the ground-truth ``.mid`` file.
        path_est: Path to the transcribed ``.mid`` file.

    Returns:
        Tuple ``(acc, f_no_offset, f)``, all in percent:

        - frame accuracy from ``get_frame_accuracy``;
        - note F-measure ignoring offsets;
        - note F-measure including offsets.

        Note matching uses ``ONSET_TOL`` and ``OFFSET_RATIO_VAL``.

        If either file cannot be loaded or scored, ``(0, 0, 0)`` is
        returned and a warning naming both paths is printed to stderr.
        This is best-effort: one bad or missing file must not abort the
        whole evaluation, but it should not go unnoticed either.
    """
    import sys

    try:
        pm_ref = pretty_midi.PrettyMIDI(path_ref)
        pm_est = pretty_midi.PrettyMIDI(path_est)

        ref_int, ref_pit = _note_arrays(pm_ref)
        est_int, est_pit = _note_arrays(pm_est)

        sc = mir_eval.transcription.evaluate(
            ref_int, ref_pit, est_int, est_pit,
            onset_tolerance=ONSET_TOL, offset_ratio=OFFSET_RATIO_VAL,
        )
        acc = get_frame_accuracy(pm_ref, pm_est)
    except Exception as err:  # deliberate best-effort boundary, see docstring
        print(
            f"WARNING: could not score {path_est!r} against {path_ref!r} "
            f"({type(err).__name__}: {err}); scoring it as 0.",
            file=sys.stderr,
        )
        return 0, 0, 0

    return acc, sc['F-measure_no_offset'] * 100, sc['F-measure'] * 100
|
|
|
def main():
    """Evaluate both models on every ground-truth MIDI file and report results.

    For each ``.mid`` file in ``DIR_GT``, the prediction with the same stem
    is looked up in ``DIR_PRED_MY`` (CornetAI) and ``DIR_PRED_OFF`` (Basic
    Pitch) and scored with ``get_full_metrics``. Per-file scores are written
    to ``Tabla_TFG_Estilo_Paper.csv``, and the per-column means are printed
    as a summary table.
    """
    print("--- 📊 EVALUACIÓN FINAL TFG (ESTILO PAPER SPOTIFY) ---")
    res = []
    # os.listdir order is arbitrary; sort so the CSV row order is reproducible.
    gts = sorted(f for f in os.listdir(DIR_GT) if f.endswith(".mid"))

    if not gts:
        # With no rows, df.mean() below is empty and m['CAI_Acc'] raises KeyError.
        print(f"No .mid files found in {DIR_GT!r}; nothing to evaluate.")
        return

    for gt_file in gts:
        name = os.path.splitext(gt_file)[0]
        gt_path = os.path.join(DIR_GT, gt_file)
        c_acc, c_fno, c_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_MY, name + ".mid"))
        o_acc, o_fno, o_f = get_full_metrics(gt_path, os.path.join(DIR_PRED_OFF, name + ".mid"))

        res.append({
            "Archivo": name,
            "CAI_Acc": c_acc, "CAI_Fno": c_fno, "CAI_F": c_f,
            "OFF_Acc": o_acc, "OFF_Fno": o_fno, "OFF_F": o_f,
        })

    df = pd.DataFrame(res)
    # Unweighted mean over files (each file counts equally regardless of its
    # note count); "Archivo" is text, hence numeric_only.
    m = df.mean(numeric_only=True)

    print("\n" + "=" * 50)
    print(f"{'Model':<20} | {'Acc':<8} | {'Fno':<8} | {'F':<8}")
    print("-" * 50)
    print(f"{'CornetAI (V3)':<20} | {m['CAI_Acc']:<8.2f} | {m['CAI_Fno']:<8.2f} | {m['CAI_F']:<8.2f}")
    print(f"{'Basic Pitch':<20} | {m['OFF_Acc']:<8.2f} | {m['OFF_Fno']:<8.2f} | {m['OFF_F']:<8.2f}")
    print("=" * 50)

    df.to_csv("Tabla_TFG_Estilo_Paper.csv", index=False)
    print("\n Tabla final generada: 'Tabla_TFG_Estilo_Paper.csv'")
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported (e.g. to reuse get_full_metrics elsewhere).
if __name__ == "__main__":
    main()