# MIDI Generation Pipeline: Text-to-Music

Complete pipelines for training and inference of **text-conditioned MIDI generation** using both GPT2-style and Qwen3-based autoregressive models.

## Two Architectures

### 1. GPT2-Style (`train_midi_gpt.py`)
- **From-scratch** GPT2 model with custom vocabulary
- ~50M parameters (configurable)
- Fast training, good for experimentation

### 2. Qwen3-0.6B (`train_midi_qwen3.py`) ⭐ Recommended
- **Pretrained LLM** with vocabulary expansion (inspired by MIDI-LLM)
- 751M parameters with rich text understanding
- **Tied embeddings** automatically handled (see the sketch below)
- Apache-2.0 license
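
The expansion step looks roughly like this (a minimal sketch assuming the standard `transformers` API; the actual logic lives in `train_midi_qwen3.py`). With tied input/output embeddings, `resize_token_embeddings` keeps the LM head consistent automatically:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pretrained base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)

# Add 3 special tokens plus 516 REMI vocabulary tokens
midi_tokens = [f"<|midi_{i}|>" for i in range(516)]
tokenizer.add_tokens(["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"] + midi_tokens)

# Resize the embedding matrix to match; the tied output head follows automatically
model.resize_token_embeddings(len(tokenizer))
```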

## Files

| File | Purpose |
|------|---------|
| `prepare_dataset.py` | Preprocess for GPT2 pipeline |
| `prepare_dataset_qwen3.py` | Preprocess for Qwen3 pipeline (rich prompts) |
| `train_midi_gpt.py` | Train GPT2-style model |
| `train_midi_qwen3.py` | Fine-tune Qwen3-0.6B with MIDI vocab expansion |
| `inference_midi_gpt.py` | Generate MIDI with GPT2 model |
| `inference_midi_qwen3.py` | Generate MIDI with Qwen3 model (for local checkpoints) |
| `inference_trained_model.py` | **Generate MIDI with trained model from HF Hub** |
| `create_synthetic_dataset.py` | Generate synthetic test data |
| `test_end_to_end.py` | Validate GPT2 pipeline |
| `test_qwen3_e2e.py` | Validate Qwen3 pipeline |
| `run_qwen3_training.py` | One-command GPU training script |

## Trained Model Available

**[rahuldshetty/midi-qwen3-v1](https://huggingface.co/rahuldshetty/midi-qwen3-v1)**

A trained Qwen3-0.6B model with expanded MIDI vocabulary (152,188 tokens total).

### Quick Inference with Trained Model

```bash
# Install dependencies
pip install transformers torch datasets miditok miditoolkit accelerate

# Generate MIDI from a text prompt (simplest way)
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --prompt "A dark electronic piece with synth strings in D minor, 140 BPM" \
    --output_path my_song.mid \
    --max_midi_tokens 1024 \
    --temperature 1.0 \
    --top_k 50 \
    --top_p 0.92

# Use a random prompt from the dataset
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --dataset_prompt \
    --num_samples 3 \
    --output_path generated.mid

# Generate multiple variations
python inference_trained_model.py \
    --model_id rahuldshetty/midi-qwen3-v1 \
    --prompt "A cheerful piano piece in C major, 120 BPM" \
    --num_samples 5 \
    --temperature 0.8 \
    --output_path song.mid
```

### Python API for Inference

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
from miditok import REMI
from pathlib import Path
import json, tempfile, shutil

# 1. Download model
model_id = "rahuldshetty/midi-qwen3-v1"
temp_dir = Path(tempfile.mkdtemp())
snapshot_download(repo_id=model_id, local_dir=str(temp_dir))

# 2. Load and expand tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
with open(temp_dir / "config.json") as f:
    config = json.load(f)
expanded_vocab = config["vocab_size"]       # 152188
n_added = expanded_vocab - len(tokenizer)   # 519 tokens added
midi_vocab_size = n_added - 3               # 516 MIDI tokens

# Add tokens
midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
tokenizer.add_tokens(["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"] + midi_tokens)

# 3. Load model
model = AutoModelForCausalLM.from_pretrained(
    str(temp_dir),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()

# 4. Build prompt and tokenize
prompt = "A cheerful jazz piece with piano and saxophone in C major, 120 BPM"
text_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Special tokens
bos_midi = tokenizer.convert_tokens_to_ids("<|midi_start|>")
eos_midi = tokenizer.convert_tokens_to_ids("<|midi_end|>")
pad_midi = tokenizer.convert_tokens_to_ids("<|midi_pad|>")
midi_offset = tokenizer.convert_tokens_to_ids("<|midi_0|>")  # first MIDI token ID (base vocab + 3 specials)

# 5. Generate MIDI tokens
input_tensor = torch.tensor([text_ids + [bos_midi]], dtype=torch.long, device=model.device)
with torch.no_grad():
    output = model.generate(
        input_tensor,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.92,
        pad_token_id=pad_midi,
        eos_token_id=eos_midi,
    )

# 6. Extract and decode MIDI tokens
generated = output[0].tolist()
bos_idx = generated.index(bos_midi)
midi_raw = generated[bos_idx + 1:]
midi_raw = [t for t in midi_raw if t not in (eos_midi, pad_midi)]
midi_ids = [t - midi_offset for t in midi_raw if t >= midi_offset]
midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]

# 7. Decode to MIDI file
midi_tokenizer = REMI(params=str(temp_dir / "midi_tokenizer_init"))
midi = midi_tokenizer.decode(midi_ids)
midi.dump_midi("output.mid")  # miditok v3 returns a symusic Score; on miditok v2 use midi.dump(...)
print(f"Generated {len(midi_ids)} MIDI tokens → output.mid")

# Cleanup
shutil.rmtree(temp_dir, ignore_errors=True)
```

## Dataset

Processed dataset: [rahuldshetty/midi-generation-dataset](https://huggingface.co/datasets/rahuldshetty/midi-generation-dataset)

Source: [B-K/midi-dataset-2](https://huggingface.co/datasets/B-K/midi-dataset-2) (MidiCaps with MIDI bytes)
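
To peek at the processed data before training, a quick inspection works (a sketch; the exact column names come from the preprocessing script, so we only print the schema here):

```python
from datasets import load_dataset

# Stream the processed dataset to inspect its schema without a full download
ds = load_dataset("rahuldshetty/midi-generation-dataset", split="train", streaming=True)
first = next(iter(ds))
print(first.keys())  # column names produced by prepare_dataset_qwen3.py
```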

## Quick Start: Train Your Own

### 1. Install Dependencies

```bash
pip install transformers torch datasets miditok miditoolkit accelerate
```

### 2. Prepare Dataset

```bash
python prepare_dataset_qwen3.py \
    --dataset B-K/midi-dataset-2 \
    --output_dir ./midi_data_qwen3 \
    --max_seq_len 2048
```

### 3. Train

```bash
python train_midi_qwen3.py \
    --dataset rahuldshetty/midi-generation-dataset \
    --output_dir ./midi_qwen3_model \
    --num_epochs 20 \
    --batch_size 2 \
    --gradient_accumulation_steps 8 \
    --bf16 \
    --gradient_checkpointing \
    --push_to_hub \
    --hub_model_id yourname/midi-qwen3-v1
```

Or use the one-command script:

```bash
python run_qwen3_training.py
```

### 4. Generate MIDI

```bash
python inference_midi_qwen3.py \
    --model_dir ./midi_qwen3_model/final \
    --prompt "A cheerful jazz piece with piano and saxophone in C major, 120 BPM" \
    --output_path output.mid \
    --max_midi_tokens 1024
```

## Qwen3 Architecture Details

### Vocabulary Expansion
- Qwen3 base tokenizer: **151,669** entries (the model's embedding matrix is padded to 151,936)
- MIDI special tokens: `<|midi_start|>`, `<|midi_end|>`, `<|midi_pad|>` (3 tokens)
- MIDI vocab tokens: `<|midi_0|>` ... `<|midi_515|>` (516 REMI tokens)
- **Total vocab: 152,188** (151,669 + 3 + 516), matching the trained model's `config.json`; the resulting token-ID layout is sketched below
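
The ID layout implied by those numbers (IDs are assigned in the order the tokens are added, specials first):

```python
BASE        = 151_669           # len(tokenizer) before expansion
BOS_MIDI    = BASE              # <|midi_start|>
EOS_MIDI    = BASE + 1          # <|midi_end|>
PAD_MIDI    = BASE + 2          # <|midi_pad|>
MIDI_OFFSET = BASE + 3          # <|midi_0|> starts here
TOTAL       = MIDI_OFFSET + 516
assert TOTAL == 152_188         # matches vocab_size in config.json
```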

### Training Labels
- Text prefix → `-100` (masked out of the loss)
- MIDI tokens + special tokens → their actual IDs
- The model learns only music generation, not prompt reproduction (see the sketch below)
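
A minimal sketch of this label construction (assumed to mirror what `prepare_dataset_qwen3.py` produces; names are illustrative):

```python
# One training example: the text prompt conditions generation, only MIDI is predicted.
# text_ids / midi_ids are token-ID lists; bos_midi / eos_midi are special-token IDs.
def build_example(text_ids: list[int], midi_ids: list[int],
                  bos_midi: int, eos_midi: int) -> dict:
    input_ids = text_ids + [bos_midi] + midi_ids + [eos_midi]
    # -100 is ignored by the cross-entropy loss, so the text prefix is not trained on
    labels = [-100] * len(text_ids) + [bos_midi] + midi_ids + [eos_midi]
    return {"input_ids": input_ids, "labels": labels}
```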

### Rich Prompt Format
```
You are a world-class composer. Please compose some music according to the following description:
Description: [caption]
Genre: [genre]
Mood: [mood]
Key: [key]
Time Signature: [time_signature]
Tempo: [tempo] BPM
Duration: [duration] seconds
Instruments: [instruments]
Chords: [chords]
```
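
A small helper that assembles this prompt from metadata fields (a sketch; the keys follow the template above, not necessarily the dataset's column names):

```python
def build_prompt(meta: dict) -> str:
    """Render the rich prompt from a metadata dict with the keys shown above."""
    return (
        "You are a world-class composer. Please compose some music "
        "according to the following description:\n"
        f"Description: {meta['caption']}\n"
        f"Genre: {meta['genre']}\n"
        f"Mood: {meta['mood']}\n"
        f"Key: {meta['key']}\n"
        f"Time Signature: {meta['time_signature']}\n"
        f"Tempo: {meta['tempo']} BPM\n"
        f"Duration: {meta['duration']} seconds\n"
        f"Instruments: {meta['instruments']}\n"
        f"Chords: {meta['chords']}"
    )
```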

## Recommended Datasets

| Dataset | Link | Description |
|---------|------|-------------|
| B-K/midi-dataset-2 | [HF](https://huggingface.co/datasets/B-K/midi-dataset-2) | Best: rich metadata + MIDI bytes |
| amaai-lab/MidiCaps | [HF](https://huggingface.co/datasets/amaai-lab/MidiCaps) | 168K captions (no MIDI bytes) |
| foldl/midi | [HF](https://huggingface.co/datasets/foldl/midi) | Name + genre + MIDI bytes |

## Hardware Recommendations

| Model | GPU | Batch | Notes |
|-------|-----|-------|-------|
| GPT2 (50M) | t4-small | 4 | Fast experimentation |
| Qwen3-0.6B | a10g-large | 2 | Enable gradient checkpointing (see sketch below) |
| Qwen3-0.6B | a100-large | 4 | Full training |
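
For the A10G row, the memory-saving settings correspond to the training flags shown earlier; in HF `Trainer` terms they look roughly like this (a sketch, not the script's exact configuration):

```python
from transformers import TrainingArguments

# Memory-saving configuration for a ~24 GB GPU (mirrors the CLI flags above)
args = TrainingArguments(
    output_dir="./midi_qwen3_model",
    num_train_epochs=20,
    per_device_train_batch_size=2,   # effective batch = 2 * 8 = 16
    gradient_accumulation_steps=8,
    bf16=True,                       # bfloat16 training
    gradient_checkpointing=True,     # trade compute for activation memory
)
```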

## SOTA References

- **MIDI-LLM** (Wu et al., 2025): LLM vocab expansion for MIDI
- **MIDI-GPT** (Pasquier et al., 2025): GPT2 for MIDI
- **text2midi** (Bhandari et al., AAAI 2025): T5 encoder + decoder