"""Run GravityLLM inference on a Spatial9 constraint payload.

Loads a trained model (PEFT adapter or fully merged model), builds a prompt
from an input JSON payload, generates a Spatial9Scene JSON completion,
optionally validates it against a JSON schema, and writes/prints the result.
"""

import argparse
import json
import re
from pathlib import Path

import torch
from jsonschema import Draft7Validator
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# Fixed instruction prefix prepended to every prompt so the model emits
# raw JSON (no markdown fences, no explanations).
SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer.\n\n"
)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for a single inference run."""
    parser = argparse.ArgumentParser(
        description="Run GravityLLM inference on a Spatial9 constraint payload."
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Path or Hub repo id for trained model or adapter.",
    )
    parser.add_argument("--input_json", type=Path, required=True)
    parser.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    parser.add_argument("--output_json", type=Path, default=None)
    parser.add_argument("--max_new_tokens", type=int, default=900)
    parser.add_argument("--temperature", type=float, default=0.35)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--validate", action="store_true")
    return parser.parse_args()


def load_model_and_tokenizer(model_dir: str):
    """Load tokenizer and model from *model_dir*.

    Tries a PEFT adapter load first (``AutoPeftModelForCausalLM``); if that
    fails — e.g. the directory holds a fully merged model with no adapter
    config — falls back to a plain causal-LM load with the same settings.

    Args:
        model_dir: Local path or Hub repo id of the trained model/adapter.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # bf16 + automatic device placement only when a GPU is present; on CPU
    # both stay at library defaults (None).
    has_cuda = torch.cuda.is_available()
    load_kwargs = dict(
        torch_dtype=torch.bfloat16 if has_cuda else None,
        device_map="auto" if has_cuda else None,
        trust_remote_code=True,
    )
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    except Exception:
        # NOTE(review): broad on purpose — PEFT raises a variety of exception
        # types when the directory is not an adapter checkpoint.
        model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.eval()
    return model, tokenizer


def extract_first_json(text: str) -> str:
    """Return the outermost ``{...}`` span of *text*, or the stripped text.

    The greedy regex spans from the first ``{`` to the last ``}``, which is
    the desired behavior for one JSON object possibly surrounded by chatter.
    """
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    return match.group(0).strip() if match else text.strip()


def validate_output(schema_path: Path, output_text: str):
    """Parse *output_text* as JSON and validate it against the schema.

    Args:
        schema_path: Path to the Draft-7 JSON schema file.
        output_text: Candidate JSON document as a string.

    Returns:
        ``(data, errors)``: the parsed object and a list of jsonschema
        ValidationError objects sorted by JSON path (empty when valid).

    Raises:
        json.JSONDecodeError: if *output_text* is not valid JSON.
    """
    schema = json.loads(schema_path.read_text(encoding="utf-8"))
    data = json.loads(output_text)
    validator = Draft7Validator(schema)
    errors = sorted(validator.iter_errors(data), key=lambda e: list(e.path))
    return data, errors


def main() -> None:
    """CLI entry point: build the prompt, generate, validate, and emit JSON."""
    args = parse_args()
    payload = json.loads(args.input_json.read_text(encoding="utf-8"))
    model, tokenizer = load_model_and_tokenizer(args.model_dir)

    prompt = (
        SYSTEM_PREFIX
        + "INPUT:\n"
        + json.dumps(payload, ensure_ascii=False, indent=2)
        + "\n\nOUTPUT:\n"
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    # model.device is "cpu" when no GPU is available, so moving
    # unconditionally is equivalent to the CUDA-guarded move and also
    # handles non-CUDA accelerators placed by device_map.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # BUGFIX: temperature <= 0 is invalid for sampling in transformers and
    # previously raised; fall back to greedy decoding in that case.
    do_sample = args.temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=do_sample,
            temperature=args.temperature if do_sample else None,
            top_p=args.top_p if do_sample else None,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # BUGFIX: slice the completion at the TOKEN level rather than decoding
    # the full sequence and chopping by the re-decoded prompt's character
    # length: decode(encode(prompt)) is not guaranteed to round-trip exactly,
    # so the old string-length slice could misalign and corrupt the JSON.
    prompt_len = inputs["input_ids"].shape[1]
    raw_completion = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    json_text = extract_first_json(raw_completion)

    if args.validate:
        try:
            _, errors = validate_output(args.schema_path, json_text)
        except Exception as exc:
            # Covers both JSON parse failures and unreadable schema files.
            print(f"Validation failed: {exc}")
        else:
            if errors:
                print("Validation: INVALID")
                for err in errors[:20]:
                    path = ".".join(str(p) for p in err.path)
                    print(f"- {path}: {err.message}")
            else:
                print("Validation: VALID")

    if args.output_json:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(json_text + "\n", encoding="utf-8")
        print(f"Wrote output to {args.output_json}")

    print(json_text)


if __name__ == "__main__":
    main()