| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | from tqdm import tqdm |
| | import logging |
| | import os |
| | from verification import init_model, MODEL_LIST |
| | import soundfile as sf |
| | import torch |
| | import numpy as np |
| | import torch.nn.functional as F |
| | from torchaudio.transforms import Resample |
| | import torch.multiprocessing as mp |
| |
|
# Console logging shared by the main process and (because spawned workers re-import
# this module) every worker process.
console_format = logging.Formatter(
    "[%(asctime)s][%(filename)s:%(levelname)s][%(process)d:%(threadName)s]%(message)s"
)
console_handler = logging.StreamHandler()
console_handler.setFormatter(console_format)
console_handler.setLevel(logging.INFO)


def reset_root_handlers(handler, logger=None):
    """Replace every handler attached to *logger* with the single *handler*.

    Args:
        handler: the ``logging.Handler`` to install.
        logger: logger to reset; defaults to the root logger.
    """
    target = logging.root if logger is None else logger
    # Iterate over a snapshot: removing from the live ``handlers`` list while
    # iterating it skips every other element and leaves stale handlers behind.
    for old_handler in list(target.handlers):
        target.removeHandler(old_handler)
    target.addHandler(handler)


reset_root_handlers(console_handler)
logging.root.setLevel(logging.INFO)
| |
|
| |
|
# Speaker-verification backbone requested from ``verification.init_model``.
MODEL_NAME = "wavlm_large"

# Exporting S3PRL_PATH opts in to the unispeech NPU patch; the patch is applied
# at import time, before any worker builds a model.
S3PRL_PATH = os.environ.get("S3PRL_PATH")
if S3PRL_PATH is None:
    pass
else:
    import patch_unispeech

    logging.info("Applying Patches for unispeech!!!")
    patch_unispeech.patch_for_npu()
| |
|
| |
|
def get_ref_and_gen_files(
    test_lst, test_folder, task_queue
):
    """Enqueue one ``(ref_file, gen_file)`` path pair per entry of *test_lst*.

    Each line of *test_lst* is ``|``-separated: field 0 names the reference audio
    and field 2 the generated audio. From each, only the basename up to its first
    ``.`` is kept, and the pair is mapped to ``{test_folder}/{name}_ref.wav`` and
    ``{test_folder}/{name}_gen.wav``. Blank lines are skipped; lines with fewer
    than three fields are logged and skipped instead of raising IndexError.

    Args:
        test_lst: path to the ``|``-separated list file.
        test_folder: folder holding the ``*_ref.wav`` / ``*_gen.wav`` files.
        task_queue: queue-like object with ``put``; receives ``(ref, gen)`` tuples.
    """

    def _stem(path):
        # Basename truncated at its first "." (e.g. "dir/a.b.wav" -> "a"), matching
        # the naming convention of the *_ref.wav / *_gen.wav files.
        return path.split("/")[-1].split(".")[0]

    with open(test_lst, "r") as fp:
        for lineno, raw_line in enumerate(fp, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split("|")
            if len(fields) < 3:
                logging.warning(f"SKIP malformed line {lineno} in {test_lst}: {line!r}")
                continue
            ref_file = f"{test_folder}/{_stem(fields[0])}_ref.wav"
            gen_file = f"{test_folder}/{_stem(fields[2])}_gen.wav"
            task_queue.put((ref_file, gen_file))
| |
|
| |
|
_TARGET_SAMPLE_RATE = 16000


def _load_mono_16k(path):
    """Read *path* and return a ``(1, num_samples)`` float32 tensor resampled to 16 kHz."""
    samples, sample_rate = sf.read(path)
    wav = torch.from_numpy(samples).float()
    # sf.read returns a 1-D array for mono files but (frames, channels) for
    # multi-channel ones; average the channels so stereo input gets the same
    # (1, num_samples) layout as mono instead of a (1, frames, channels) tensor.
    if wav.dim() > 1:
        wav = wav.mean(dim=1)
    wav = wav.unsqueeze(0)
    # Resampling between identical rates would be a no-op, so skip it.
    if sample_rate != _TARGET_SAMPLE_RATE:
        wav = Resample(orig_freq=sample_rate, new_freq=_TARGET_SAMPLE_RATE)(wav)
    return wav


def eval_speaker_similarity(model, wav1, wav2, rank):
    """Return the cosine similarity between the speaker embeddings of two audio files.

    Args:
        model: speaker-verification model (built by ``init_model``). It is called on a
            ``(1, num_samples)`` 16 kHz waveform tensor; presumably it returns a
            ``(1, emb_dim)`` embedding batch — TODO confirm against ``verification``.
        wav1: path to the first (reference) audio file.
        wav2: path to the second (generated) audio file.
        rank: CUDA device index; tensors are moved to ``cuda:{rank}``.

    Returns:
        float: cosine similarity of the first batch element, in [-1.0, 1.0].
    """
    device = f"cuda:{rank}"
    ref_wav = _load_mono_16k(wav1).cuda(device)
    gen_wav = _load_mono_16k(wav2).cuda(device)

    model.eval()
    with torch.no_grad():
        ref_emb = model(ref_wav)
        gen_emb = model(gen_wav)

    score = F.cosine_similarity(ref_emb, gen_emb)[0].item()
    logging.info("The similarity score between two audios is %.4f (-1.0, 1.0).", score)
    return score
| |
|
| |
|
def eval_proc(model_path, task_queue, rank, sim_list):
    """Worker loop: score ``(ref, gen)`` pairs from *task_queue* on ``cuda:{rank}``.

    Loads the ``MODEL_NAME`` speaker-verification model once, then consumes pairs
    until a ``None`` sentinel arrives. Pairs whose files are missing are logged and
    skipped. A pair whose evaluation raises is logged with its traceback and skipped.
    Each successful score is appended to *sim_list* as ``(sim, ref, gen)``.

    Args:
        model_path: checkpoint path passed to ``init_model``.
        task_queue: shared queue yielding ``(ref_path, gen_path)`` tuples or ``None``.
        rank: CUDA device index for this worker.
        sim_list: shared list receiving ``(similarity, ref_path, gen_path)`` tuples.

    Raises:
        ValueError: if ``MODEL_NAME`` is not one of ``MODEL_LIST``.
    """
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if MODEL_NAME not in MODEL_LIST:
        raise ValueError('The model_name should be in {}'.format(MODEL_LIST))
    model = init_model(MODEL_NAME, model_path)
    model.to(f"cuda:{rank}")

    while True:
        # Outside the try: if the queue itself fails (e.g. manager gone), the worker
        # should stop rather than log a FAIL with stale/unbound names forever.
        new_record = task_queue.get()
        if new_record is None:
            logging.info("FINISH processing all inputs")
            break

        ref, gen = new_record[0], new_record[1]
        logging.info(f"eval SIM: {ref} v.s. {gen}")

        if not os.path.exists(ref) or not os.path.exists(gen):
            logging.info(f"MISSING: {ref} v.s. {gen}")
            continue

        try:
            sim = eval_speaker_similarity(model, ref, gen, rank)
        except Exception:
            # Best-effort per pair: record the traceback and move on. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit stop the worker.
            logging.exception(f"FAIL to eval SIM: {ref} v.s. {gen}")
            continue
        sim_list.append((sim, ref, gen))
| | |
| |
|
def main(args):
    """Compute the average speaker similarity over all pairs listed in ``args.test_lst``.

    Spawns one ``eval_proc`` worker per visible CUDA device. Workers share a
    manager queue of ``(ref, gen)`` paths and a manager list of results.

    Args:
        args: parsed CLI namespace with ``test_lst``, ``test_path``, ``model_path``
            and optional ``log_file``.

    Returns:
        The mean similarity rounded to 3 decimals, or ``nan`` when no pair was
        evaluated.

    Raises:
        ValueError: if ``CUDA_VISIBLE_DEVICES`` is set but names no device.
    """
    # --log-file is optional; FileHandler(filename=None) would raise TypeError.
    # NOTE: spawned workers re-import the module but never run main(), so this
    # file only receives records logged by the parent process.
    if args.log_file:
        logging.root.addHandler(logging.FileHandler(filename=args.log_file, mode="w"))

    device_list = [0]
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # Skip empty entries (e.g. a trailing comma) instead of crashing in int("").
        device_list = [int(x) for x in visible.split(",") if x.strip()]
        if not device_list:
            raise ValueError("CUDA_VISIBLE_DEVICES is set but lists no devices")

    logging.info(f"Using devices: {device_list}")
    n_procs = len(device_list)
    ctx = mp.get_context('spawn')
    with ctx.Manager() as manager:
        sim_list = manager.list()
        task_queue = manager.Queue()
        get_ref_and_gen_files(args.test_lst, args.test_path, task_queue)

        processes = []
        # Workers address devices as cuda:0..n-1, i.e. positions among the
        # visible devices, not the raw IDs listed in CUDA_VISIBLE_DEVICES.
        for rank in range(n_procs):
            # One sentinel per worker, queued after all real tasks.
            task_queue.put(None)
            processes.append(
                ctx.Process(target=eval_proc, args=(args.model_path, task_queue, rank, sim_list))
            )

        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()
            if proc.exitcode != 0:
                logging.warning(f"worker pid={proc.pid} exited with code {proc.exitcode}")

        # Copy results out of the manager proxy before the manager shuts down.
        results = list(sim_list)

    for sim, ref, gen in results:
        logging.info(f"{ref} vs {gen} : {sim}")
    logging.info("total evaluated wav pairs: %d" % (len(results)))
    if not results:
        logging.warning("No wav pairs were evaluated; average similarity is undefined.")
        return float("nan")

    avg_sim = round(np.mean([sim for sim, _, _ in results]), 3)
    logging.info("The average similarity score of %s is %.4f (-1.0, 1.0)." % (args.test_path, avg_sim))
    return avg_sim
| |
|
| |
|
def build_arg_parser():
    """Return the command-line parser for the speaker-similarity evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test-path",
        required=True,
        type=str,
        help="folder of wav files",
    )
    parser.add_argument(
        "--test-lst",
        required=True,
        type=str,
        help="path to test file lst",
    )
    parser.add_argument(
        "--log-file",
        required=False,
        type=str,
        default=None,
        help="optional path to a log file (overwritten on each run)",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="./wavlm-sv",
        help="path to sv model",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)
| |
|