# MIDI Generation Pipeline: Text-to-Music

Complete pipelines for training and inference of **text-conditioned MIDI generation**, using both GPT2-style and Qwen3-based autoregressive models.

## Two Architectures

### 1. GPT2-Style (`train_midi_gpt.py`)

- **From-scratch** GPT2 model with custom vocabulary
- ~50M parameters (configurable)
- Fast training, good for experimentation

### 2. Qwen3-0.6B (`train_midi_qwen3.py`) ⭐ Recommended

- **Pretrained LLM** with vocabulary expansion (inspired by MIDI-LLM)
- 751M parameters with rich text understanding
- **Tied embeddings** automatically handled
- Apache-2.0 license

## Files

| File | Purpose |
|------|---------|
| `prepare_dataset.py` | Preprocess for GPT2 pipeline |
| `prepare_dataset_qwen3.py` | Preprocess for Qwen3 pipeline (rich prompts) |
| `train_midi_gpt.py` | Train GPT2-style model |
| `train_midi_qwen3.py` | Fine-tune Qwen3-0.6B with MIDI vocab expansion |
| `inference_midi_gpt.py` | Generate MIDI with GPT2 model |
| `inference_midi_qwen3.py` | Generate MIDI with Qwen3 model (for local checkpoints) |
| `inference_trained_model.py` | **Generate MIDI with trained model from HF Hub** |
| `create_synthetic_dataset.py` | Generate synthetic test data |
| `test_end_to_end.py` | Validate GPT2 pipeline |
| `test_qwen3_e2e.py` | Validate Qwen3 pipeline |
| `run_qwen3_training.py` | One-command GPU training script |

## Trained Model Available

**[rahuldshetty/midi-qwen3-v1](https://huggingface.co/rahuldshetty/midi-qwen3-v1)**

A trained Qwen3-0.6B model with expanded MIDI vocabulary (152,188 tokens total).

### Quick Inference with Trained Model

```bash
# Install dependencies
pip install transformers torch datasets miditok miditoolkit accelerate

# Generate MIDI from a text prompt (simplest way)
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --prompt "A dark electronic piece with synth strings in D minor, 140 BPM" \
  --output_path my_song.mid \
  --max_midi_tokens 1024 \
  --temperature 1.0 \
  --top_k 50 \
  --top_p 0.92

# Use a random prompt from the dataset
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --dataset_prompt \
  --num_samples 3 \
  --output_path generated.mid

# Generate multiple variations
python inference_trained_model.py \
  --model_id rahuldshetty/midi-qwen3-v1 \
  --prompt "A cheerful piano piece in C major, 120 BPM" \
  --num_samples 5 \
  --temperature 0.8 \
  --output_path song.mid
```

### Python API for Inference

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import snapshot_download
from miditok import REMI
from pathlib import Path
import json, tempfile, shutil

# 1. Download model
model_id = "rahuldshetty/midi-qwen3-v1"
temp_dir = Path(tempfile.mkdtemp())
snapshot_download(repo_id=model_id, local_dir=str(temp_dir))

# 2. Load and expand tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
with open(temp_dir / "config.json") as f:
    config = json.load(f)
expanded_vocab = config["vocab_size"]      # 152188
n_added = expanded_vocab - len(tokenizer)  # 519 tokens added
midi_vocab_size = n_added - 3              # 516 MIDI tokens

# Add tokens (special tokens first, then the MIDI vocabulary,
# matching the order used during training)
midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
tokenizer.add_tokens(["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"] + midi_tokens)

# 3. Load model
model = AutoModelForCausalLM.from_pretrained(
    str(temp_dir),
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto" if torch.cuda.is_available() else None,
)
model.eval()

# 4. Build prompt and tokenize
prompt = "A cheerful jazz piece with piano and saxophone in C major, 120 BPM"
text_ids = tokenizer.encode(prompt, add_special_tokens=False)

# Special tokens
bos_midi = tokenizer.convert_tokens_to_ids("<|midi_start|>")
eos_midi = tokenizer.convert_tokens_to_ids("<|midi_end|>")
pad_midi = tokenizer.convert_tokens_to_ids("<|midi_pad|>")
# ID of the first MIDI token; deriving it from the tokenizer keeps the
# decode offset aligned with the IDs assigned by add_tokens above
midi_offset = tokenizer.convert_tokens_to_ids("<|midi_0|>")

# 5. Generate MIDI tokens
input_tensor = torch.tensor([text_ids + [bos_midi]], dtype=torch.long, device=model.device)
with torch.no_grad():
    output = model.generate(
        input_tensor,
        max_new_tokens=512,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.92,
        pad_token_id=pad_midi,
        eos_token_id=eos_midi,
    )

# 6. Extract and decode MIDI tokens
generated = output[0].tolist()
bos_idx = generated.index(bos_midi)
midi_raw = generated[bos_idx + 1:]
midi_raw = [t for t in midi_raw if t not in (eos_midi, pad_midi)]
midi_ids = [t - midi_offset for t in midi_raw if t >= midi_offset]
midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]

# 7. Decode to MIDI file
midi_tokenizer = REMI(params=str(temp_dir / "midi_tokenizer_init"))
midi = midi_tokenizer.decode(midi_ids)
# Older miditok returns a miditoolkit MidiFile (.dump); newer releases
# return a symusic Score (.dump_midi)
if hasattr(midi, "dump_midi"):
    midi.dump_midi("output.mid")
else:
    midi.dump("output.mid")
print(f"Generated {len(midi_ids)} MIDI tokens → output.mid")

# Cleanup
shutil.rmtree(temp_dir, ignore_errors=True)
```

## Dataset

Processed dataset: [rahuldshetty/midi-generation-dataset](https://huggingface.co/datasets/rahuldshetty/midi-generation-dataset)

Source: [B-K/midi-dataset-2](https://huggingface.co/datasets/B-K/midi-dataset-2) (MidiCaps with MIDI bytes)

## Quick Start: Train Your Own

### 1. Install Dependencies

```bash
pip install transformers torch datasets miditok miditoolkit accelerate
```

### 2. Prepare Dataset

```bash
python prepare_dataset_qwen3.py \
  --dataset B-K/midi-dataset-2 \
  --output_dir ./midi_data_qwen3 \
  --max_seq_len 2048
```

### 3. Train

```bash
python train_midi_qwen3.py \
  --dataset rahuldshetty/midi-generation-dataset \
  --output_dir ./midi_qwen3_model \
  --num_epochs 20 \
  --batch_size 2 \
  --gradient_accumulation_steps 8 \
  --bf16 \
  --gradient_checkpointing \
  --push_to_hub \
  --hub_model_id yourname/midi-qwen3-v1
```

Or use the one-command script:

```bash
python run_qwen3_training.py
```

### 4. Generate MIDI

```bash
python inference_midi_qwen3.py \
  --model_dir ./midi_qwen3_model/final \
  --prompt "A cheerful jazz piece with piano and saxophone in C major, 120 BPM" \
  --output_path output.mid \
  --max_midi_tokens 1024
```

## Qwen3 Architecture Details

### Vocabulary Expansion

- Qwen3 base tokenizer: **151,669** tokens (the model config pads its embedding table to 151,936)
- MIDI special tokens: `<|midi_start|>`, `<|midi_end|>`, `<|midi_pad|>`
- MIDI vocab tokens: `<|midi_0|>` ... `<|midi_515|>` (REMI tokenization)
- **Total vocab: 152,188** (151,669 + 3 + 516), matching the trained model's `config.json`; see the sketch below
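
Because Qwen3 ties its input and output embeddings, a single `resize_token_embeddings` call extends both the embedding matrix and the LM head, which is what "tied embeddings automatically handled" refers to. A minimal sketch of the expansion step (the token names follow this repo's scheme; mean-initializing the new rows is an assumption, not necessarily what `train_midi_qwen3.py` does):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the pretrained base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)

# Append 3 special tokens + 516 REMI tokens after the existing vocabulary
new_tokens = ["<|midi_start|>", "<|midi_end|>", "<|midi_pad|>"]
new_tokens += [f"<|midi_{i}|>" for i in range(516)]
n_added = tokenizer.add_tokens(new_tokens)  # 519

# One resize covers the LM head too, because the weights are tied
model.resize_token_embeddings(len(tokenizer))  # 151,669 + 519 = 152,188

# Assumption: initialize the new rows to the mean of the pretrained
# embeddings, which tends to behave better early in training than random init
with torch.no_grad():
    emb = model.get_input_embeddings().weight
    emb[-n_added:] = emb[:-n_added].mean(dim=0, keepdim=True)
```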
### Training Labels

- Text prefix → `-100` (not trained on)
- MIDI tokens + special tokens → actual IDs
- Model learns only music generation (see the masking sketch below)
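
A minimal sketch of how these labels could be assembled, assuming one packed text-plus-MIDI sequence per example (the function name, right-padding scheme, and `max_len` default are illustrative, not the exact code in `train_midi_qwen3.py`):

```python
import torch

def build_labels(text_ids, midi_ids, bos_midi, eos_midi, pad_midi, max_len=2048):
    """Pack text + MIDI into one sequence and mask the text out of the loss."""
    input_ids = text_ids + [bos_midi] + midi_ids + [eos_midi]
    # -100 is the ignore_index of PyTorch cross-entropy: the text prefix
    # contributes no gradient, while MIDI and special tokens keep real IDs
    labels = [-100] * len(text_ids) + [bos_midi] + midi_ids + [eos_midi]

    # Truncate, then right-pad; padding is also masked out of the loss
    input_ids, labels = input_ids[:max_len], labels[:max_len]
    n_pad = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * n_pad
    input_ids = input_ids + [pad_midi] * n_pad
    labels = labels + [-100] * n_pad
    return (torch.tensor(input_ids), torch.tensor(attention_mask),
            torch.tensor(labels))
```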
### Rich Prompt Format

```
You are a world-class composer. Please compose some music according to the following description:

Description: [caption]
Genre: [genre]
Mood: [mood]
Key: [key]
Time Signature: [time_signature]
Tempo: [tempo] BPM
Duration: [duration] seconds
Instruments: [instruments]
Chords: [chords]
```

## Recommended Datasets

| Dataset | Link | Description |
|---------|------|-------------|
| B-K/midi-dataset-2 | [HF](https://huggingface.co/datasets/B-K/midi-dataset-2) | Best fit: rich metadata + MIDI bytes |
| amaai-lab/MidiCaps | [HF](https://huggingface.co/datasets/amaai-lab/MidiCaps) | 168K captions (no MIDI bytes) |
| foldl/midi | [HF](https://huggingface.co/datasets/foldl/midi) | Name + genre + MIDI bytes |

## Hardware Recommendations

| Model | GPU | Batch | Notes |
|-------|-----|-------|-------|
| GPT2 (50M) | t4-small | 4 | Fast experimentation |
| Qwen3-0.6B | a10g-large | 2 | Enable `gradient_checkpointing` |
| Qwen3-0.6B | a100-large | 4 | Full training |

## SOTA References

- **MIDI-LLM** (Wu et al., 2025): LLM vocab expansion for MIDI
- **MIDI-GPT** (Pasquier et al., 2025): GPT2 for MIDI
- **text2midi** (Bhandari et al., AAAI 2025): T5 encoder + decoder