| import numpy as np |
| import torch |
|
|
| from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator |
|
|
|
|
def test_pwgan_generator():
    """Check ParallelWaveganGenerator output lengths for training and inference.

    A batch of random 80-channel conditioning features with 5 frames is fed
    through the generator. ``forward`` must emit one output channel whose
    length is the frame count times the product of ``upsample_factors``.
    After ``remove_weight_norm``, ``inference`` must use the same hop length
    but cover 4 additional frames.
    """
    upsample_factors = [4, 4, 4, 4]
    # Output samples produced per conditioning frame (4 * 4 * 4 * 4 = 256).
    hop_length = int(np.prod(upsample_factors))
    batch_size, aux_channels, num_frames = 2, 80, 5
    # NOTE(review): `inference` is expected to yield 4 more frames' worth of
    # samples than `forward`; presumably it pads the conditioning by 2 frames
    # on each side - confirm against ParallelWaveganGenerator.inference.
    inference_extra_frames = 4

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=aux_channels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, aux_channels, num_frames))

    output = model(dummy_c)
    expected = (batch_size, 1, num_frames * hop_length)
    # torch.Size is a tuple subclass: compare directly rather than via np.all.
    assert tuple(output.shape) == expected, output.shape

    model.remove_weight_norm()
    # Only shapes are checked, so skip building an autograd graph.
    with torch.no_grad():
        output = model.inference(dummy_c)
    expected = (batch_size, 1, (num_frames + inference_extra_frames) * hop_length)
    assert tuple(output.shape) == expected, output.shape
|
|