# midi-generation-scripts / inference_trained_model.py
"""
Inference script specifically for the trained Qwen3 MIDI generation model.
Auto-detects model structure from checkpoint — no metadata file needed.
Works with: https://huggingface.co/rahuldshetty/midi-qwen3-v1
Usage:
python inference_trained_model.py --prompt "A dark electronic piece in D minor, 140 BPM"
python inference_trained_model.py --dataset_prompt --num_samples 3
python inference_trained_model.py --prompt "Jazz piano in C major, 120 BPM" --temperature 0.8 --max_midi_tokens 512
"""
import argparse
import json
import random
import shutil
import tempfile
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import snapshot_download
from miditok import REMI, TokenizerConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
BOS_MIDI_TOKEN = "<|midi_start|>"
EOS_MIDI_TOKEN = "<|midi_end|>"
PAD_MIDI_TOKEN = "<|midi_pad|>"
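
# These three markers are re-added below in the same order they were (presumably) added
# during training, so their ids line up with the checkpoint's expanded embedding rows.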


def setup(model_id: str, device: str | None = None):
    """Download model, reconstruct tokenizer, load everything."""
    print(f"Downloading model: {model_id}")
    temp_dir = Path(tempfile.mkdtemp())
    snapshot_download(repo_id=model_id, local_dir=str(temp_dir))
    print(f" Downloaded to: {temp_dir}")

    # Read model config
    with open(temp_dir / "config.json") as f:
        cfg = json.load(f)
    expanded_vocab = cfg["vocab_size"]  # e.g. 152188
    print(f" Model vocab_size: {expanded_vocab}")

    # Load base Qwen3 tokenizer
    print("Loading base Qwen3 tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen3-0.6B", trust_remote_code=True
    )
    base_vocab = len(tokenizer)
    print(f" Base vocab: {base_vocab}")
    # Infer added tokens
    n_added = expanded_vocab - base_vocab
    n_special = 3
    midi_vocab_size = n_added - n_special
    print(f" Added tokens: {n_added} (special={n_special}, midi={midi_vocab_size})")

    # Add tokens to tokenizer
    midi_tokens = [f"<|midi_{i}|>" for i in range(midi_vocab_size)]
    tokenizer.add_tokens([BOS_MIDI_TOKEN, EOS_MIDI_TOKEN, PAD_MIDI_TOKEN] + midi_tokens)
    print(f" Expanded tokenizer: {len(tokenizer)} tokens")

    # Load model
    print("Loading model...")
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        str(temp_dir),
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda" else None,
    )
    if device != "cuda":
        model = model.to(device)
    model.eval()
    print(f" Model on device: {model.device}")

    # Load MidiTok tokenizer — find tokenizer.json in the model repo.
    # MidiTok expects a FILE path, not a directory.
    midi_tok_dir = temp_dir / "midi_tokenizer_init"
    midi_tokenizer = None
    if midi_tok_dir.is_dir():
        # MidiTok saves as tokenizer.json inside the folder
        tok_json = midi_tok_dir / "tokenizer.json"
        if tok_json.exists():
            print(f" Found MidiTok tokenizer.json: {tok_json}")
            midi_tokenizer = REMI(params=str(tok_json))
        else:
            # Fallback: find any .json file in the directory
            json_files = sorted(midi_tok_dir.glob("*.json"))
            if json_files:
                print(f" Found MidiTok config: {json_files[0]}")
                midi_tokenizer = REMI(params=str(json_files[0]))
            else:
                print(f" WARNING: No .json found in {midi_tok_dir}, files: {list(midi_tok_dir.iterdir())}")
    else:
        print(f" WARNING: midi_tokenizer_init/ not found at {midi_tok_dir}")
    if midi_tokenizer is None:
        print(" Creating default MidiTok REMI tokenizer")
        tok_cfg = TokenizerConfig(
            num_velocities=16, use_chords=True, use_tempos=True,
            use_time_signatures=True, use_programs=True, num_programs=128,
        )
        midi_tokenizer = REMI(tok_cfg)
    print(f" MidiTok vocab: {midi_tokenizer.vocab_size}")

    # Build metadata dict
    bos_id = tokenizer.convert_tokens_to_ids(BOS_MIDI_TOKEN)
    eos_id = tokenizer.convert_tokens_to_ids(EOS_MIDI_TOKEN)
    pad_id = tokenizer.convert_tokens_to_ids(PAD_MIDI_TOKEN)
    midi_offset = base_vocab + n_special
    metadata = {
        "base_vocab": base_vocab,
        "midi_vocab_size": midi_vocab_size,
        "midi_offset": midi_offset,
        "bos_id": bos_id,
        "eos_id": eos_id,
        "pad_id": pad_id,
        "max_length": 2048,
    }
    print(f" Metadata: midi_offset={midi_offset}, bos={bos_id}, eos={eos_id}")
    return model, tokenizer, midi_tokenizer, metadata, temp_dir


def generate(model, tokenizer, midi_tokenizer, metadata, prompt, max_midi_tokens,
             temperature, top_k, top_p, device):
    """Generate MIDI from text prompt."""
    bos_id = metadata["bos_id"]
    eos_id = metadata["eos_id"]
    pad_id = metadata["pad_id"]
    midi_offset = metadata["midi_offset"]
    midi_vocab_size = metadata["midi_vocab_size"]
    max_length = metadata["max_length"]

    # Tokenize prompt
    text_ids = tokenizer.encode(prompt, add_special_tokens=False)
    shown = prompt[:200] + "..." if len(prompt) > 200 else prompt
    print(f"\nPrompt: '{shown}'")
    print(f"Text tokens: {len(text_ids)}")
    input_ids = text_ids + [bos_id]
    generated = input_ids.copy()
    print(f"Generating up to {max_midi_tokens} MIDI tokens...")
    with torch.no_grad():
        for i in range(max_midi_tokens):
            if len(generated) >= max_length:
                print(f" Max length reached ({max_length})")
                break
            inp = torch.tensor([generated], dtype=torch.long, device=device)
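            # NOTE: each step re-runs the full sequence through the model (no KV cache),
            # so decoding cost grows quadratically with length; simple, but slow for
            # long generations.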
            logits = model(inp).logits[:, -1, :]
            logits = logits / temperature
            # Top-k: mask everything below the k-th largest logit
            if top_k > 0:
                idx_rm = logits < torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
                logits[idx_rm] = float("-inf")
            # Top-p (nucleus): keep the smallest set of tokens whose cumulative
            # probability exceeds top_p; the shift below always keeps the top token
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumsum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                rm = cumsum > top_p
                rm[..., 1:] = rm[..., :-1].clone()
                rm[..., 0] = False
                rm = rm.scatter(-1, sorted_indices, rm)
                logits[rm] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_tok)
            if next_tok == eos_id:
                print(f" EOS at step {i+1}")
                break
            if next_tok == pad_id:
                print(f" PAD at step {i+1}")
                break

    # Extract MIDI tokens
    try:
        bos_idx = generated.index(bos_id)
    except ValueError:
        print("ERROR: BOS token not found!")
        return None
    raw = generated[bos_idx + 1:]
    raw = [t for t in raw if t not in (eos_id, pad_id)]
    midi_ids = [t - midi_offset for t in raw if t >= midi_offset]
    midi_ids = [t for t in midi_ids if 0 <= t < midi_vocab_size]
    print(f"Extracted {len(midi_ids)} MIDI tokens")
    if not midi_ids:
        print("ERROR: No valid MIDI tokens!")
        return None
    midi = midi_tokenizer.decode(midi_ids)
    return midi
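
# Round-trip sanity check one could run on the decoded score (a sketch, assuming
# MidiTok v3's encode() API, which returns a TokSequence or a list of them):
#   seq = midi_tokenizer.encode(midi)   # re-tokenize the decoded score
#   ids = seq.ids if not isinstance(seq, list) else seq[0].ids
#   print(f"Re-encoded to {len(ids)} tokens")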


def save_midi(midi_score, output_path: str):
    """Save a MidiTok decoded score to a MIDI file.

    MidiTok v3 returns a symusic ScoreTick object which uses dump_midi()
    instead of the old miditoolkit MidiFile.dump() method.
    """
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # symusic ScoreTick (MidiTok v3) uses dump_midi()
    if hasattr(midi_score, "dump_midi"):
        midi_score.dump_midi(str(out))
    # Fallback for older miditoolkit MidiFile objects
    elif hasattr(midi_score, "dump"):
        midi_score.dump(str(out))
    else:
        raise AttributeError(
            f"Cannot save MIDI object of type {type(midi_score)}. "
            f"Available attrs: {[a for a in dir(midi_score) if not a.startswith('_')]}"
        )
    size = out.stat().st_size
    print(f"Saved: {out} ({size} bytes)")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="rahuldshetty/midi-qwen3-v1")
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--dataset_prompt", action="store_true")
    parser.add_argument("--output_path", type=str, default="generated.mid")
    parser.add_argument("--max_midi_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.92)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    if args.seed is not None:
        torch.manual_seed(args.seed)

    model, tokenizer, midi_tokenizer, metadata, temp_dir = setup(args.model_id, device)

    # Get prompt
    if args.dataset_prompt:
        ds = load_dataset("rahuldshetty/midi-generation-dataset", split="train")
        sample = ds[random.randint(0, len(ds) - 1)]
        prompt = sample["prompt"]
        print(f"\nUsing dataset prompt (first 200 chars):\n{prompt[:200]}...")
    elif args.prompt:
        prompt = args.prompt
    else:
        prompt = "A cheerful piano piece in C major, 120 BPM, classical style"

    # Generate
    for i in range(args.num_samples):
        out = Path(args.output_path)
        if args.num_samples > 1:
            out = out.with_name(f"{out.stem}_{i+1}{out.suffix}")
        midi = generate(
            model, tokenizer, midi_tokenizer, metadata,
            prompt, args.max_midi_tokens,
            args.temperature, args.top_k, args.top_p, device,
        )
        if midi is not None:
            save_midi(midi, str(out))
        else:
            print(f"Sample {i+1} failed")

    shutil.rmtree(temp_dir, ignore_errors=True)
    print("\nDone!")


if __name__ == "__main__":
    main()