# Cocktail Party AI - Conv-TasNet 3-Source Separator
This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation. The model takes a mixed speech waveform and estimates 3 separated source waveforms.
## Checkpoint

- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Checkpoint epoch: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB
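The loss and SI-SNR values above mirror each other, consistent with the usual Conv-TasNet objective of maximizing scale-invariant SNR by minimizing its negative. As a quick reference, here is a dependency-free sketch of the metric (the helper name `si_snr_db` is ours, not part of this repo; Asteroid's own loss implementation is authoritative):

```python
import math

def si_snr_db(est, ref):
    # Scale-invariant SNR: project the estimate onto the reference,
    # then compare the energy of that target component to the residual.
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref)
    scale = dot / ref_energy
    target = [scale * r for r in ref]
    residual = [e - t for e, t in zip(est, target)]
    target_energy = sum(t * t for t in target)
    residual_energy = sum(n * n for n in residual)
    return 10.0 * math.log10(target_energy / residual_energy)
```

Because the estimate is rescaled before comparison, the metric is unchanged if the estimated waveform is multiplied by any nonzero gain.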
## Files

```
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```
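The two config files supply the dataset and model hyperparameters consumed by the usage example below, so they are expected to look roughly like the following sketch. `n_src` and `sample_rate` match the values stated in this README; the remaining numbers are illustrative Conv-TasNet defaults, not necessarily this checkpoint's — the actual files in `configs/` are authoritative.

```yaml
# configs/data.yaml (sketch)
dataset:
  n_src: 3
  sample_rate: 16000

# configs/train.yaml (sketch)
model:
  n_filters: 512
  filter_length: 16
  stride: 8
  n_blocks: 8
  n_repeats: 3
  bn_chan: 128
  hid_chan: 512
  skip_chan: 128
  norm_type: gLN
  mask_act: relu
```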
## Usage

Install dependencies:

```bash
pip install -r requirements.txt
```
Load the checkpoint with the project code:
```python
import yaml
import torch

from src.model import build_model, load_checkpoint

with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]
ds = data_cfg["dataset"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)
load_checkpoint(model, "best.ckpt", device)
model.eval()
```
To separate a WAV file using this project:

```bash
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
```
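Since the checkpoint was trained at 16 kHz, input mixtures presumably need to be at that rate; `torchaudio.functional.resample` handles the conversion in practice. As a dependency-free illustration of what resampling does, here is a naive linear-interpolation sketch (`resample_linear` is a hypothetical helper, not part of this repo, and is not suitable for production audio since it lacks anti-aliasing):

```python
def resample_linear(samples, src_rate, dst_rate):
    # Naive linear-interpolation resampler: for each output sample,
    # find its fractional position in the input and blend the two
    # nearest input samples. No anti-aliasing filter is applied.
    n_out = int(len(samples) * dst_rate / src_rate)
    out = []
    for i in range(n_out):
        pos = i * src_rate / dst_rate
        lo = int(pos)
        hi = min(lo + 1, len(samples) - 1)
        frac = pos - lo
        out.append(samples[lo] * (1.0 - frac) + samples[hi] * frac)
    return out
```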
## Notes

- This is a research/training checkpoint, not a fully packaged `transformers` pipeline.
- It depends on PyTorch, Torchaudio, and Asteroid.