| | from argparse import ArgumentParser
|
| | from tqdm import tqdm
|
| |
|
| | import numpy as np
|
| | import matplotlib.pyplot as plt
|
| | import seaborn as sns
|
| |
|
| | from anticipation.vocab import MIDI_TIME_OFFSET, MIDI_START_OFFSET, TIME_RESOLUTION, SEPARATOR
|
| | from anticipation.ops import max_time
|
| |
|
# Global matplotlib typography: serif text set in Computer Modern at 16pt.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 16,
})
|
| |
|
def loghist(filename, data, title, xlabel):
    """Save a kernel density plot of `data` drawn on a log-scaled x-axis.

    Args:
        filename: destination of the image, written at 300 dpi. When it is a
            filesystem path, missing parent directories are created first.
        data: the samples passed to seaborn's `kdeplot`.
        title: accepted for interface compatibility; it is not drawn on the
            figure.
        xlabel: label for the x-axis. The y-axis is always labelled 'Density'.
    """
    import os

    sns.set_style('whitegrid')

    # Own the figure explicitly so it can be released afterwards; pyplot keeps
    # every figure alive until it is closed.
    fig = plt.figure(figsize=(10,4))
    try:
        # The axis is switched to log scale before the density is drawn.
        plt.xscale('log')
        plt.xlabel(xlabel)
        plt.ylabel('Density')

        # Grid lines on both major and minor ticks.
        plt.grid(True, which='both', linestyle='-', linewidth=0.5)

        sns.kdeplot(data, bw_adjust=1.0)

        plt.tight_layout()

        # savefig does not create directories; make sure the target exists.
        if isinstance(filename, (str, os.PathLike)):
            directory = os.path.dirname(filename)
            if directory:
                os.makedirs(directory, exist_ok=True)

        fig.savefig(filename, dpi=300)
    finally:
        plt.close(fig)
|
| |
|
| |
|
if __name__ == '__main__':
    # Script entry point: read a tokenized MIDI dataset (one whitespace-separated
    # integer sequence per line), sample every 10th line, print token/time
    # statistics for the sample, and plot the tokens-per-second distribution.
    parser = ArgumentParser(description='calculate statistics of a tokenized MIDI dataset')
    # The filename is mandatory: without it there is nothing to open.
    parser.add_argument('-f', '--filename', required=True,
            help='file containing a tokenized MIDI dataset')
    parser.add_argument('-i', '--interarrival',
            action='store_true',
            help='request interarrival-time encoding (default to arrival-time encoding)')
    args = parser.parse_args()

    print(f'Calculating statistics for {args.filename}')
    time_lengths = []   # per-sequence duration, in units of 1/TIME_RESOLUTION seconds
    token_counts = []   # per-sequence number of tokens
    with open(args.filename, 'r') as f:
        # Stream the file line by line instead of materializing every line in
        # memory; the progress bar then shows a running count without a total.
        for i, line in tqdm(enumerate(f)):
            # Only every 10th line is sampled; all statistics below describe
            # that sample, not the full file.
            if i % 10 != 0:
                continue

            tokens = [int(token) for token in line.split()]
            if not tokens:
                # Blank line: no sequence to measure.
                continue

            if args.interarrival:
                # Tokens below MIDI_START_OFFSET are treated as time tokens;
                # their values relative to MIDI_TIME_OFFSET sum to the duration.
                time_lengths.append(sum(t-MIDI_TIME_OFFSET for t in tokens if t < MIDI_START_OFFSET))
                token_counts.append(len(tokens))
            else:
                # Sequences containing a SEPARATOR token are excluded.
                if SEPARATOR in tokens:
                    continue

                # The first token is dropped before measuring
                # NOTE(review): presumably a control/prefix token -- confirm.
                events = tokens[1:]
                time_lengths.append(max_time(events, seconds=False))
                token_counts.append(len(events))

    # Sequences with zero duration have no defined rate; they still count toward
    # the totals but are left out of the per-sequence distribution.
    tokens_per_second = [TIME_RESOLUTION*count/float(length)
                         for (count, length) in zip(token_counts, time_lengths)
                         if length > 0]
    total_time = sum(time_lengths)
    if not tokens_per_second or total_time <= 0:
        raise SystemExit(f'No sequences with positive duration found in {args.filename}')

    print('Total tokens:', sum(token_counts))
    print(f'Total time: {float(total_time)/(3600*TIME_RESOLUTION)} hours')
    # Aggregate rate: total tokens over total time.
    print('Mean tokens-per-second:', TIME_RESOLUTION*sum(token_counts)/float(total_time))
    print('Std tokens-per-second:', np.std(tokens_per_second))
    # Unweighted mean of the per-sequence rates (differs from the aggregate above).
    print(np.mean(tokens_per_second))

    loghist('output/tokens_per_second.png',
            tokens_per_second,
            'Distribution of Tokens per Second',
            'Tokens per Second (log10 scale)')
|
| |
|