| """ |
| Inference script specifically for the trained Qwen3 MIDI generation model. |
| |
| Auto-detects model structure from checkpoint — no metadata file needed. |
| Works with: https://huggingface.co/rahuldshetty/midi-qwen3-v1 |
| |
| Usage: |
| python inference_trained_model.py --prompt "A dark electronic piece in D minor, 140 BPM" |
| python inference_trained_model.py --dataset_prompt --num_samples 3 |
| python inference_trained_model.py --prompt "Jazz piano in C major, 120 BPM" --temperature 0.8 --max_midi_tokens 512 |
| """ |
| import argparse |
| import json |
| import tempfile |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| from datasets import load_dataset |
| from huggingface_hub import snapshot_download |
| from miditok import REMI, TokenizerConfig |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| BOS_MIDI_TOKEN = "<|midi_start|>" |
| EOS_MIDI_TOKEN = "<|midi_end|>" |
| PAD_MIDI_TOKEN = "<|midi_pad|>" |


def setup(model_id: str, device: str = None):
    """Download the model, reconstruct both tokenizers, and load everything."""
    print(f"Downloading model: {model_id}")
    temp_dir = Path(tempfile.mkdtemp())
    snapshot_download(repo_id=model_id, local_dir=str(temp_dir))
    print(f" Downloaded to: {temp_dir}")

    # Read the expanded vocabulary size straight from the checkpoint config.
    with open(temp_dir / "config.json") as f:
        cfg = json.load(f)
    expanded_vocab = cfg["vocab_size"]
    print(f" Model vocab_size: {expanded_vocab}")

    print("Loading base Qwen3 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-0.6B", trust_remote_code=True
    )
    base_vocab = len(tokenizer)
    print(f" Base vocab: {base_vocab}")

    # Infer how many tokens training added: 3 special markers plus one
    # token per MidiTok vocabulary entry.
    n_added = expanded_vocab - base_vocab
    n_special = 3
    midi_vocab_size = n_added - n_special
    print(f" Added tokens: {n_added} (special={n_special}, midi={midi_vocab_size})")

    # Re-add the tokens in the same order as training so the ids line up.
    midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
    tokenizer.add_tokens([BOS_MIDI_TOKEN, EOS_MIDI_TOKEN, PAD_MIDI_TOKEN] + midi_tokens)
    print(f" Expanded tokenizer: {len(tokenizer)} tokens")

    print("Loading model...")
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        str(temp_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
    )
    if device != "cuda":
        model = model.to(device)
    model.eval()
    print(f" Model on device: {model.device}")

    # Reconstruct the MidiTok tokenizer from files shipped with the
    # checkpoint, falling back to a default REMI config if none are found.
    midi_tok_dir = temp_dir / "midi_tokenizer_init"
    midi_tokenizer = None

    if midi_tok_dir.is_dir():
        tok_json = midi_tok_dir / "tokenizer.json"
        if tok_json.exists():
            print(f" Found MidiTok tokenizer.json: {tok_json}")
            midi_tokenizer = REMI(params=str(tok_json))
        else:
            json_files = sorted(midi_tok_dir.glob("*.json"))
            if json_files:
                print(f" Found MidiTok config: {json_files[0]}")
                midi_tokenizer = REMI(params=str(json_files[0]))
            else:
                print(f" WARNING: No .json found in {midi_tok_dir}, files: {list(midi_tok_dir.iterdir())}")
    else:
        print(f" WARNING: midi_tokenizer_init/ not found at {midi_tok_dir}")

    if midi_tokenizer is None:
        print(" Creating default MidiTok REMI tokenizer")
        tok_cfg = TokenizerConfig(
            num_velocities=16, use_chords=True, use_tempos=True,
            use_time_signatures=True, use_programs=True, num_programs=128,
        )
        midi_tokenizer = REMI(tok_cfg)

    print(f" MidiTok vocab: {midi_tokenizer.vocab_size}")

    # Resolve the special-token ids and the offset where MIDI tokens start.
    bos_id = tokenizer.convert_tokens_to_ids(BOS_MIDI_TOKEN)
    eos_id = tokenizer.convert_tokens_to_ids(EOS_MIDI_TOKEN)
    pad_id = tokenizer.convert_tokens_to_ids(PAD_MIDI_TOKEN)
    midi_offset = base_vocab + n_special

    metadata = {
        "base_vocab": base_vocab,
        "midi_vocab_size": midi_vocab_size,
        "midi_offset": midi_offset,
        "bos_id": bos_id,
        "eos_id": eos_id,
        "pad_id": pad_id,
        "max_length": 2048,
    }
    print(f" Metadata: midi_offset={midi_offset}, bos={bos_id}, eos={eos_id}")

    return model, tokenizer, midi_tokenizer, metadata, temp_dir


def generate(model, tokenizer, midi_tokenizer, metadata, prompt, max_midi_tokens,
             temperature, top_k, top_p, device):
    """Generate MIDI from a text prompt."""
    bos_id = metadata["bos_id"]
    eos_id = metadata["eos_id"]
    pad_id = metadata["pad_id"]
    midi_offset = metadata["midi_offset"]
    midi_vocab_size = metadata["midi_vocab_size"]
    max_length = metadata["max_length"]

    # Encode the text prompt, then append the MIDI start marker so the
    # model continues with MIDI tokens.
    text_ids = tokenizer.encode(prompt, add_special_tokens=False)
    shown = f"{prompt[:200]}..." if len(prompt) > 200 else prompt
    print(f"\nPrompt: '{shown}'")
    print(f"Text tokens: {len(text_ids)}")

    input_ids = text_ids + [bos_id]
    generated = input_ids.copy()

    print(f"Generating up to {max_midi_tokens} MIDI tokens...")
    with torch.no_grad():
        for i in range(max_midi_tokens):
            if len(generated) >= max_length:
                print(f" Max length reached ({max_length})")
                break

            # Full forward pass each step; simple but O(n^2) without a KV cache.
            inp = torch.tensor([generated], dtype=torch.long, device=device)
            logits = model(inp).logits[:, -1, :]
            logits = logits / temperature

            # Top-k: drop everything below the k-th highest logit.
            if top_k > 0:
                idx_rm = logits < torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
                logits[idx_rm] = float("-inf")

            # Top-p (nucleus): keep the smallest prefix of the sorted
            # distribution whose cumulative probability exceeds top_p.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumsum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                rm = cumsum > top_p
                rm[..., 1:] = rm[..., :-1].clone()  # shift so the boundary token survives
                rm[..., 0] = False
                rm = rm.scatter(-1, sorted_indices, rm)
                logits[rm] = float("-inf")

            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_tok)

            if next_tok == eos_id:
                print(f" EOS at step {i+1}")
                break
            if next_tok == pad_id:
                print(f" PAD at step {i+1}")
                break
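
    # Faster alternative (a sketch, untested against this checkpoint): the
    # same top-k/top-p sampling via transformers' built-in generate(), which
    # reuses the KV cache instead of re-running the full forward pass:
    #
    #   inp = torch.tensor([input_ids], dtype=torch.long, device=device)
    #   out = model.generate(
    #       inp, max_new_tokens=max_midi_tokens, do_sample=True,
    #       temperature=temperature, top_k=top_k, top_p=top_p,
    #       eos_token_id=[eos_id, pad_id], pad_token_id=pad_id,
    #   )
    #   generated = out[0].tolist()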

    # Extract the MIDI span: everything after BOS, minus EOS/PAD, remapped
    # from model token ids back to MidiTok ids.
    try:
        bos_idx = generated.index(bos_id)
    except ValueError:
        print("ERROR: BOS token not found!")
        return None

    raw = generated[bos_idx + 1:]
    raw = [t for t in raw if t not in (eos_id, pad_id)]
    midi_ids = [t - midi_offset for t in raw if t >= midi_offset]
    midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]

    print(f"Extracted {len(midi_ids)} MIDI tokens")
    if not midi_ids:
        print("ERROR: No valid MIDI tokens!")
        return None

    midi = midi_tokenizer.decode(midi_ids)
    return midi


def save_midi(midi_score, output_path: str):
    """Save a MidiTok-decoded score to a MIDI file.

    MidiTok v3 returns a symusic ScoreTick object, which uses dump_midi()
    instead of the old miditoolkit MidiFile.dump() method.
    """
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    if hasattr(midi_score, "dump_midi"):
        # symusic Score (MidiTok v3)
        midi_score.dump_midi(str(out))
    elif hasattr(midi_score, "dump"):
        # miditoolkit MidiFile (MidiTok v2 and earlier)
        midi_score.dump(str(out))
    else:
        raise AttributeError(
            f"Cannot save MIDI object of type {type(midi_score)}. "
            f"Available attrs: {[a for a in dir(midi_score) if not a.startswith('_')]}"
        )

    size = out.stat().st_size
    print(f"Saved: {out} ({size} bytes)")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="rahuldshetty/midi-qwen3-v1")
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--dataset_prompt", action="store_true")
    parser.add_argument("--output_path", type=str, default="generated.mid")
    parser.add_argument("--max_midi_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.92)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if args.seed is not None:
        torch.manual_seed(args.seed)

    model, tokenizer, midi_tokenizer, metadata, temp_dir = setup(args.model_id, device)

    # Pick a prompt: a random one from the training dataset, the CLI prompt,
    # or a built-in default.
    if args.dataset_prompt:
        import random

        ds = load_dataset("rahuldshetty/midi-generation-dataset", split="train")
        sample = ds[random.randint(0, len(ds) - 1)]
        prompt = sample["prompt"]
        print(f"\nUsing dataset prompt (first 200 chars):\n{prompt[:200]}...")
    elif args.prompt:
        prompt = args.prompt
    else:
        prompt = "A cheerful piano piece in C major, 120 BPM, classical style"

    for i in range(args.num_samples):
        out = Path(args.output_path)
        if args.num_samples > 1:
            out = out.with_name(f"{out.stem}_{i+1}{out.suffix}")

        midi = generate(
            model, tokenizer, midi_tokenizer, metadata,
            prompt, args.max_midi_tokens,
            args.temperature, args.top_k, args.top_p, device,
        )

        # generate() signals failure by returning None, so test identity
        # rather than truthiness.
        if midi is not None:
            save_midi(midi, str(out))
        else:
            print(f"Sample {i+1} failed")

    shutil.rmtree(temp_dir, ignore_errors=True)
    print("\nDone!")


if __name__ == "__main__":
    main()