import sys import os import random import logging import yaml import numpy as np from contextlib import contextmanager import torch from speakerlab.utils.fileio import load_yaml def parse_config(config_file): if config_file.endwith('.yaml'): config = load_yaml(config_file) else: raise Exception("Other formats not currently supported.") return config def set_seed(seed=0): np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = True def get_logger(fpath=None, fmt=None): if fmt is None: fmt = "%(asctime)s - %(levelname)s: %(message)s" logging.basicConfig(level=logging.INFO, format=fmt) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) if fpath is not None: handler = logging.FileHandler(fpath) handler.setFormatter(logging.Formatter(fmt)) logger.addHandler(handler) return logger def get_utt2spk_dict(utt2spk, suffix=''): temp_dict={} with open(utt2spk,'r') as utt2spk_f: lines = utt2spk_f.readlines() for i in lines: i=i.strip().split() if suffix == '' or suffix is None: key_i = i[0] value_spk = i[1] else: key_i = i[0]+'_'+suffix value_spk = i[1]+'_'+suffix if key_i in temp_dict: raise ValueError('The key must be unique.') temp_dict[key_i]=value_spk return temp_dict def get_wavscp_dict(wavscp, suffix=''): temp_dict={} with open(wavscp, 'r') as wavscp_f: lines = wavscp_f.readlines() for i in lines: i=i.strip().split() if suffix == '' or suffix is None: key_i = i[0] else: key_i = i[0]+'_'+suffix value_path = i[1] if key_i in temp_dict: raise ValueError('The key must be unique.') temp_dict[key_i]=value_path return temp_dict def accuracy(x, target): # x: [*, C], target: [*,] _, pred = x.topk(1) pred = pred.squeeze(-1) acc = pred.eq(target).float().mean() return acc*100 def average_precision(scores, labels): # scores: [N, ], labels: [N, ] if torch.is_tensor(scores): scores = scores.cpu().numpy() if torch.is_tensor(labels): labels = labels.cpu().numpy() if isinstance(scores, list): scores = np.array(scores) if isinstance(labels, list): labels = np.array(labels) assert isinstance(scores, np.ndarray) and isinstance( labels, np.ndarray), 'Input should be numpy.array.' assert len(scores.shape)==1 and len(labels.shape)==1 and \ scores.shape[0]==labels.shape[0] sort_idx = np.argsort(scores)[::-1] scores = scores[sort_idx] labels = labels[sort_idx] tp_count = (labels==1).sum() tp = labels.cumsum() recall = tp / tp_count precision = tp / (np.arange(len(labels)) + 1) recall = np.concatenate([[0], recall, [1]]) precision = np.concatenate([[0], precision, [0]]) # Smooth precision to be monotonically decreasing. for i in range(len(precision) - 2, -1, -1): precision[i] = np.maximum(precision[i], precision[i + 1]) indices = np.where(recall[1:] != recall[:-1])[0] + 1 average_precision = np.sum( (recall[indices] - recall[indices - 1]) * precision[indices]) return average_precision def load_params(dst_model, src_state, strict=True): dst_state = {} for k in src_state: if k.startswith('module'): dst_state[k[7:]] = src_state[k] else: dst_state[k] = src_state[k] dst_model.load_state_dict(dst_state, strict=strict) return dst_model def merge_vad(vad1: list, vad2: list): intervals = vad1 + vad2 intervals.sort(key=lambda x: x[0]) merged = [] for interval in intervals: if not merged or merged[-1][1] < interval[0]: merged.append(interval) else: merged[-1][1] = max(merged[-1][1], interval[1]) return merged class AverageMeter(object): def __init__(self, name, fmt=':f'): self.name = name self.fmt = fmt self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' return fmtstr.format(**self.__dict__) class AverageMeters(object): def __init__(self, names: list = None, fmts: list = None): self.cont = dict() if names is None or fmts is None: return for name, fmt in zip(names, fmts): self.cont[name] = AverageMeter(name, fmt) def add(self, name, fmt=':f'): self.cont[name] = AverageMeter(name, fmt) def update(self, name, val, n=1): self.cont[name].update(val, n) def avg(self, name): return self.cont[name].avg def val(self, name): return self.cont[name].val def __str__(self): return '\t'.join([str(s) for s in self.cont.values()]) class ProgressMeter(object): def __init__(self, num_batches, meters, prefix=""): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display(self, batch): entries = [self.prefix + self.batch_fmtstr.format(batch)] entries += [str(self.meters)] return '\t'.join(entries) def _get_batch_fmtstr(self, num_batches): num_digits = len(str(num_batches // 1)) fmt = '{:' + str(num_digits) + 'd}' return '[' + fmt + '/' + fmt.format(num_batches) + ']' @contextmanager def silent_print(): original_stdout = sys.stdout sys.stdout = open(os.devnull, 'w') try: yield finally: sys.stdout.close() sys.stdout = original_stdout def download_model_from_modelscope(model_id, model_revision=None, cache_dir=None): from modelscope.hub.snapshot_download import snapshot_download if cache_dir is None: cache_dir = snapshot_download( model_id, revision=model_revision, ) else: cfg_file = os.path.join(cache_dir, model_id, 'configuration.json') if not os.path.exists(cfg_file): cache_dir = snapshot_download( model_id, revision=model_revision, cache_dir=cache_dir, ) else: cache_dir = os.path.join(cache_dir, model_id) return cache_dir def circle_pad(x: torch.Tensor, target_len, dim=0): xlen = x.shape[dim] if xlen >= target_len: return x n = int(np.ceil(target_len/xlen)) xcat = torch.cat([x for _ in range(n)], dim=dim) return torch.narrow(xcat, dim, 0, target_len)