| | import csv |
| | import random |
| | from collections import defaultdict |
| | from pathlib import Path |
| |
|
| | import click |
| | import yaml |
| |
|
| |
|
| | |
@click.command(help='Randomly select test samples')
@click.argument(
    'config',
    type=click.Path(file_okay=True, dir_okay=False, resolve_path=True, writable=True, path_type=Path),
    metavar="CONFIG"
)
@click.option(
    '--rel_path',
    type=click.Path(file_okay=False, dir_okay=True, resolve_path=True, path_type=Path),
    default=None,
    help='Path that is relative to the paths mentioned in the config file.'
)
@click.option(
    '--min', '_min',
    show_default=True,
    type=click.IntRange(min=1),
    default=10,
    help='Minimum number of test samples.'
)
@click.option(
    '--max', '_max',
    show_default=True,
    type=click.IntRange(min=1),
    default=20,
    help='Maximum number of test samples (note that each speaker will have at least one test sample).'
)
@click.option(
    '--per_speaker',
    show_default=True,
    type=click.IntRange(min=1),
    default=4,
    help='Expected number of test samples per speaker.'
)
def select_test_set(config, rel_path, _min, _max, per_speaker):
    """Randomly select test samples and store them in the YAML config.

    Reads CONFIG, groups the raw datasets by speaker ID, samples test
    items spread as evenly as possible across speakers, then (after an
    optional overwrite confirmation) writes 'test_prefixes' and
    'num_valid_plots' back into CONFIG.

    Raises:
        AssertionError: on inconsistent min/max or config contents.
        ValueError: when one speaker name is mapped to two speaker IDs.
    """
    assert _min <= _max, 'min must be smaller or equal to max'
    with open(config, 'r', encoding='utf8') as f:
        hparams = yaml.safe_load(f)

    path_spk_map = _build_speaker_path_map(hparams, rel_path)
    training_cases = _collect_training_cases(path_spk_map)
    # Guard: an empty speaker list would otherwise cause a ZeroDivisionError
    # when distributing sample counts below.
    assert training_cases, 'No speakers/raw datasets found in the config file.'

    # Target size: per_speaker items per speaker, clamped into [min, max].
    total = min(_max, max(_min, per_speaker * len(training_cases)))
    test_prefixes = _sample_prefixes(training_cases, total)

    if not hparams['test_prefixes'] or click.confirm('Overwrite existing test prefixes?', abort=False):
        hparams['test_prefixes'] = test_prefixes
        hparams['num_valid_plots'] = len(test_prefixes)
        with open(config, 'w', encoding='utf8') as f:
            yaml.dump(hparams, f, sort_keys=False)
        print('Test prefixes saved.')
    else:
        print('Test prefixes not saved, aborted.')


def _build_speaker_path_map(hparams, rel_path):
    """Map each speaker ID to its list of (dataset index, raw data dir).

    Validates that 'speakers', 'raw_data_dir' and (if given) 'spk_ids'
    in *hparams* are mutually consistent; when 'spk_ids' is empty or
    missing, speaker IDs default to 0..N-1 in dataset order.
    """
    spk_ids = hparams['spk_ids']
    speakers = hparams['speakers']
    raw_data_dirs = [Path(d) for d in hparams['raw_data_dir']]
    assert isinstance(speakers, list), 'Speakers must be a list'
    assert len(speakers) == len(raw_data_dirs), \
        'Number of raw data dirs must equal number of speaker names!'
    if not spk_ids:
        spk_ids = list(range(len(raw_data_dirs)))
    else:
        assert len(spk_ids) == len(raw_data_dirs), \
            'Length of explicitly given spk_ids must equal the number of raw datasets.'
        assert max(spk_ids) < hparams['num_spk'], \
            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'

    spk_map = {}
    path_spk_map = defaultdict(list)
    for ds_id, (spk_name, raw_path, spk_id) in enumerate(zip(speakers, raw_data_dirs, spk_ids)):
        # One speaker name may own several datasets, but only one speaker ID.
        if spk_name in spk_map and spk_map[spk_name] != spk_id:
            raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                             f'with different speaker IDs: {spk_map[spk_name]} and {spk_id}.')
        spk_map[spk_name] = spk_id
        path_spk_map[spk_id].append((ds_id, rel_path / raw_path if rel_path else raw_path))
    return path_spk_map


def _collect_training_cases(path_spk_map):
    """Collect the eligible 'ds_id:name' items for every speaker.

    An item is eligible when its transcription row has a matching wav
    file under '<raw_data_dir>/wavs/'. Returns one list per speaker.
    """
    training_cases = []
    for spk_raw_dirs in path_spk_map.values():
        training_case = []
        for ds_id, raw_data_dir in spk_raw_dirs:
            with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f:
                for row in csv.DictReader(f):
                    if (raw_data_dir / 'wavs' / f'{row["name"]}.wav').exists():
                        training_case.append(f'{ds_id}:{row["name"]}')
        training_cases.append(training_case)
    return training_cases


def _sample_prefixes(training_cases, total):
    """Randomly sample roughly *total* prefixes, spread evenly across speakers.

    Every speaker contributes at least one sample even when that exceeds
    *total*. Per-speaker counts are clamped to the number of items the
    speaker actually has, so random.sample cannot raise ValueError for
    speakers with few eligible items.
    """
    quotient, remainder = divmod(total, len(training_cases))
    if quotient == 0:
        # Fewer requested samples than speakers: still take one per speaker.
        test_counts = [1] * len(training_cases)
    else:
        # The first `remainder` speakers get one extra sample.
        test_counts = [quotient + 1] * remainder + [quotient] * (len(training_cases) - remainder)
    test_prefixes = []
    for cases, count in zip(training_cases, test_counts):
        test_prefixes += sorted(random.sample(cases, min(count, len(cases))))
    return test_prefixes
| |
|
# CLI entry point: delegate argument parsing and execution to click.
if __name__ == '__main__':
    select_test_set()
| |
|