""" Inference script specifically for the trained Qwen3 MIDI generation model. Auto-detects model structure from checkpoint — no metadata file needed. Works with: https://huggingface.co/rahuldshetty/midi-qwen3-v1 Usage: python inference_trained_model.py --prompt "A dark electronic piece in D minor, 140 BPM" python inference_trained_model.py --dataset_prompt --num_samples 3 python inference_trained_model.py --prompt "Jazz piano in C major, 120 BPM" --temperature 0.8 --max_midi_tokens 512 """ import argparse import json import tempfile import shutil from pathlib import Path import torch from datasets import load_dataset from huggingface_hub import snapshot_download from miditok import REMI, TokenizerConfig from transformers import AutoModelForCausalLM, AutoTokenizer BOS_MIDI_TOKEN = "<|midi_start|>" EOS_MIDI_TOKEN = "<|midi_end|>" PAD_MIDI_TOKEN = "<|midi_pad|>" def setup(model_id: str, device: str = None): """Download model, reconstruct tokenizer, load everything.""" print(f"Downloading model: {model_id}") temp_dir = Path(tempfile.mkdtemp()) snapshot_download(repo_id=model_id, local_dir=str(temp_dir)) print(f" Downloaded to: {temp_dir}") # Read model config with open(temp_dir / "config.json") as f: cfg = json.load(f) expanded_vocab = cfg["vocab_size"] # e.g. 152188 print(f" Model vocab_size: {expanded_vocab}") # Load base Qwen3 tokenizer print("Loading base Qwen3 tokenizer...") tokenizer = AutoTokenizer.from_pretrained( "Qwen/Qwen3-0.6B", trust_remote_code=True ) base_vocab = len(tokenizer) print(f" Base vocab: {base_vocab}") # Infer added tokens n_added = expanded_vocab - base_vocab n_special = 3 midi_vocab_size = n_added - n_special print(f" Added tokens: {n_added} (special={n_special}, midi={midi_vocab_size})") # Add tokens to tokenizer midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)] tokenizer.add_tokens([BOS_MIDI_TOKEN, EOS_MIDI_TOKEN, PAD_MIDI_TOKEN] + midi_tokens) print(f" Expanded tokenizer: {len(tokenizer)} tokens") # Load model print("Loading model...") dtype = torch.bfloat16 if device == "cuda" else torch.float32 model = AutoModelForCausalLM.from_pretrained( str(temp_dir), trust_remote_code=True, torch_dtype=dtype, device_map="auto" if device == "cuda" else None, ) if device != "cuda": model = model.to(device) model.eval() print(f" Model on device: {model.device}") # Load MidiTok tokenizer — find tokenizer.json in the model repo # MidiTok expects a FILE path, not a directory midi_tok_dir = temp_dir / "midi_tokenizer_init" midi_tokenizer = None if midi_tok_dir.is_dir(): # MidiTok saves as tokenizer.json inside the folder tok_json = midi_tok_dir / "tokenizer.json" if tok_json.exists(): print(f" Found MidiTok tokenizer.json: {tok_json}") midi_tokenizer = REMI(params=str(tok_json)) else: # Fallback: find any .json file in the directory json_files = sorted(midi_tok_dir.glob("*.json")) if json_files: print(f" Found MidiTok config: {json_files[0]}") midi_tokenizer = REMI(params=str(json_files[0])) else: print(f" WARNING: No .json found in {midi_tok_dir}, files: {list(midi_tok_dir.iterdir())}") else: print(f" WARNING: midi_tokenizer_init/ not found at {midi_tok_dir}") if midi_tokenizer is None: print(" Creating default MidiTok REMI tokenizer") tok_cfg = TokenizerConfig( num_velocities=16, use_chords=True, use_tempos=True, use_time_signatures=True, use_programs=True, num_programs=128, ) midi_tokenizer = REMI(tok_cfg) print(f" MidiTok vocab: {midi_tokenizer.vocab_size}") # Build metadata dict bos_id = tokenizer.convert_tokens_to_ids(BOS_MIDI_TOKEN) eos_id = tokenizer.convert_tokens_to_ids(EOS_MIDI_TOKEN) pad_id = tokenizer.convert_tokens_to_ids(PAD_MIDI_TOKEN) midi_offset = base_vocab + n_special metadata = { "base_vocab": base_vocab, "midi_vocab_size": midi_vocab_size, "midi_offset": midi_offset, "bos_id": bos_id, "eos_id": eos_id, "pad_id": pad_id, "max_length": 2048, } print(f" Metadata: midi_offset={midi_offset}, bos={bos_id}, eos={eos_id}") return model, tokenizer, midi_tokenizer, metadata, temp_dir def generate(model, tokenizer, midi_tokenizer, metadata, prompt, max_midi_tokens, temperature, top_k, top_p, device): """Generate MIDI from text prompt.""" bos_id = metadata["bos_id"] eos_id = metadata["eos_id"] pad_id = metadata["pad_id"] midi_offset = metadata["midi_offset"] midi_vocab_size = metadata["midi_vocab_size"] max_length = metadata["max_length"] # Tokenize prompt text_ids = tokenizer.encode(prompt, add_special_tokens=False) print(f"\nPrompt: '{prompt[:200]}...'" if len(prompt) > 200 else f"\nPrompt: '{prompt}'") print(f"Text tokens: {len(text_ids)}") input_ids = text_ids + [bos_id] generated = input_ids.copy() print(f"Generating up to {max_midi_tokens} MIDI tokens...") with torch.no_grad(): for i in range(max_midi_tokens): if len(generated) >= max_length: print(f" Max length reached ({max_length})") break inp = torch.tensor([generated], dtype=torch.long, device=device) logits = model(inp).logits[:, -1, :] logits = logits / temperature # Top-k if top_k > 0: idx_rm = logits < torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None] logits[idx_rm] = float("-inf") # Top-p if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumsum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) rm = cumsum > top_p rm[..., 1:] = rm[..., :-1].clone() rm[..., 0] = False rm = rm.scatter(-1, sorted_indices, rm) logits[rm] = float("-inf") probs = torch.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1).item() generated.append(next_tok) if next_tok == eos_id: print(f" EOS at step {i+1}") break if next_tok == pad_id: print(f" PAD at step {i+1}") break # Extract MIDI tokens try: bos_idx = generated.index(bos_id) except ValueError: print("ERROR: BOS token not found!") return None raw = generated[bos_idx + 1:] raw = [t for t in raw if t not in (eos_id, pad_id)] midi_ids = [t - midi_offset for t in raw if t >= midi_offset] midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size] print(f"Extracted {len(midi_ids)} MIDI tokens") if not midi_ids: print("ERROR: No valid MIDI tokens!") return None midi = midi_tokenizer.decode(midi_ids) return midi def save_midi(midi_score, output_path: str): """Save a MidiTok decoded score to a MIDI file. MidiTok v3 returns a symusic ScoreTick object which uses dump_midi() instead of the old miditoolkit MidiFile.dump() method. """ out = Path(output_path) out.parent.mkdir(parents=True, exist_ok=True) # symusic ScoreTick (MidiTok v3) uses dump_midi() if hasattr(midi_score, "dump_midi"): midi_score.dump_midi(str(out)) # Fallback for older miditoolkit MidiFile objects elif hasattr(midi_score, "dump"): midi_score.dump(str(out)) else: raise AttributeError( f"Cannot save MIDI object of type {type(midi_score)}. " f"Available attrs: {[a for a in dir(midi_score) if not a.startswith('_')]}" ) size = out.stat().st_size print(f"Saved: {out} ({size} bytes)") def main(): parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default="rahuldshetty/midi-qwen3-v1") parser.add_argument("--prompt", type=str, default=None) parser.add_argument("--dataset_prompt", action="store_true") parser.add_argument("--output_path", type=str, default="generated.mid") parser.add_argument("--max_midi_tokens", type=int, default=1024) parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_k", type=int, default=50) parser.add_argument("--top_p", type=float, default=0.92) parser.add_argument("--num_samples", type=int, default=1) parser.add_argument("--device", type=str, default=None) parser.add_argument("--seed", type=int, default=None) args = parser.parse_args() device = args.device or ("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") if args.seed is not None: torch.manual_seed(args.seed) model, tokenizer, midi_tokenizer, metadata, temp_dir = setup(args.model_id, device) # Get prompt if args.dataset_prompt: ds = load_dataset("rahuldshetty/midi-generation-dataset", split="train") import random sample = ds[random.randint(0, len(ds) - 1)] prompt = sample["prompt"] print(f"\nUsing dataset prompt (first 200 chars):\n{prompt[:200]}...") elif args.prompt: prompt = args.prompt else: prompt = "A cheerful piano piece in C major, 120 BPM, classical style" # Generate for i in range(args.num_samples): out = Path(args.output_path) if args.num_samples > 1: out = out.with_name(f"{out.stem}_{i+1}{out.suffix}") midi = generate( model, tokenizer, midi_tokenizer, metadata, prompt, args.max_midi_tokens, args.temperature, args.top_k, args.top_p, device, ) if midi: save_midi(midi, str(out)) else: print(f"Sample {i+1} failed") shutil.rmtree(temp_dir, ignore_errors=True) print("\nDone!") if __name__ == "__main__": main()