Cocktail Party AI - Conv-TasNet 3-Source Separator

This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation. The model takes a mixed speech waveform and estimates 3 separated source waveforms.

Checkpoint

  • File: best.ckpt
  • Architecture: Asteroid ConvTasNet
  • Number of sources: 3
  • Sample rate: 16 kHz
  • Training checkpoint epoch: 68
  • Best validation loss: -2.909952
  • Approximate validation SI-SNR: 2.91 dB
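The reported loss is consistent with a negative-SI-SNR training objective (−2.91 ≈ −SI-SNR). For reference, SI-SNR can be computed as below; this is a generic NumPy sketch of the metric, not this project's own loss code:

```python
import numpy as np

def si_snr(est: np.ndarray, ref: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SNR in dB between a 1-D estimate and reference."""
    # Project the estimate onto the reference: the scale-invariant target.
    s_target = (np.dot(est, ref) / (np.dot(ref, ref) + eps)) * ref
    e_noise = est - s_target
    return 10 * np.log10((np.dot(s_target, s_target) + eps) /
                         (np.dot(e_noise, e_noise) + eps))
```

Rescaling the estimate leaves the value unchanged, which is what makes the metric scale-invariant.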

Files

best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py

Usage

Install dependencies:

pip install -r requirements.txt

Load the checkpoint with the project code:

import yaml
import torch

from src.model import build_model, load_checkpoint

# Read the training and data configs shipped with the repository.
with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]    # model hyperparameters
ds = data_cfg["dataset"]    # dataset settings (number of sources, sample rate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture with the same hyperparameters used during training.
model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)

# Load the trained weights and switch to inference mode.
load_checkpoint(model, "best.ckpt", device)
model.eval()
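With the model loaded as above, inference can be sketched as a small helper. The function name `separate_mixture` is hypothetical, and the code assumes the usual Asteroid convention that the separator maps a (batch, samples) mixture to (batch, n_src, samples) estimates:

```python
import torch

def separate_mixture(model: torch.nn.Module, mix: torch.Tensor) -> torch.Tensor:
    """Run the separator on a mono mixture of shape (samples,) or (1, samples).

    Returns estimated sources of shape (n_src, samples), assuming the model
    maps (batch, samples) to (batch, n_src, samples).
    """
    if mix.dim() == 1:
        mix = mix.unsqueeze(0)   # add a batch dimension
    model.eval()
    with torch.no_grad():
        est = model(mix)         # (1, n_src, samples)
    return est.squeeze(0)
```

For example, `sources = separate_mixture(model, waveform)` yields one row per speaker; each row can then be written out with torchaudio (remembering that the checkpoint expects 16 kHz input).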

To separate a WAV file with the bundled script:

python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt

Notes

This is a research/training checkpoint, not a fully packaged transformers pipeline. It depends on PyTorch, Torchaudio, and Asteroid.
