| | """ |
| | constrained_generator.py - JSON Schema Constrained Generation |
| | |
| | This implements constrained decoding to force valid JSON output: |
| | 1. Token-by-token validation against JSON schema |
| | 2. Backtracking on invalid JSON syntax |
| | 3. Beam search with JSON constraints |
| | 4. Schema-aware generation |
| | """ |
| |
|
import json
from typing import List, Dict

import torch
import jsonschema
from transformers import AutoTokenizer, AutoModelForCausalLM

class ConstrainedJSONGenerator:
    def __init__(self, model, tokenizer, device="mps"):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()

    def is_valid_json_prefix(self, text: str) -> bool:
        """Check if text could be the start of valid JSON."""
        text = text.strip()
        if not text:
            return True

        # The output must be a JSON object.
        if not text.startswith('{'):
            return False

        # Text that already parses is trivially a valid prefix.
        try:
            json.loads(text)
            return True
        except json.JSONDecodeError as e:
            # Permissive heuristic: "Expecting ..." and "Unterminated
            # string" errors usually mean the text is an incomplete but
            # extendable prefix; other errors are treated as broken JSON.
            msg = str(e)
            return "Expecting" in msg or "Unterminated string" in msg

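    # Illustrative behavior of the heuristic above (added examples, not
    # part of the original logic):
    #   is_valid_json_prefix('{"name": "get_weather", "arguments"') -> True
    #   is_valid_json_prefix('{"name": "get')  -> True  (unterminated string)
    #   is_valid_json_prefix('not json')       -> False (no leading brace)
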
    def get_valid_next_tokens(self, current_text: str, schema: Dict) -> List[int]:
        """Get tokens that would keep the JSON prefix valid.

        `schema` is accepted for future schema-aware filtering but is not
        used by the prefix check yet.
        """
        valid_tokens = []
        vocab_size = len(self.tokenizer)

        for token_id in range(vocab_size):
            if token_id == self.tokenizer.pad_token_id:
                continue

            token_text = self.tokenizer.decode([token_id])
            new_text = current_text + token_text

            if self.is_valid_json_prefix(new_text):
                valid_tokens.append(token_id)

            # Cap the scan for speed; this biases the candidate set toward
            # low token ids, which is acceptable only for a demo.
            if len(valid_tokens) > 50:
                break

        return valid_tokens

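    # Note: the scan above calls tokenizer.decode() once per vocabulary
    # entry at every decoding step, which is O(|V|) and slow for real
    # vocabularies. A practical version would precompute a token-id ->
    # string table once (e.g. in __init__) and reuse it across steps.
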
    def generate_constrained(self, prompt: str, schema: Dict, max_length: int = 200) -> str:
        """Generate token by token, masking out tokens that break the JSON prefix."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        generated_text = ""
        current_input_ids = inputs['input_ids'].clone()

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(current_input_ids)
                logits = outputs.logits[0, -1, :]

            valid_tokens = self.get_valid_next_tokens(generated_text, schema)

            if not valid_tokens:
                # No token keeps the prefix valid: force a closing brace,
                # or stop if the object already appears closed.
                if not generated_text.strip().endswith('}'):
                    next_token_id = self.tokenizer.encode('}', add_special_tokens=False)[0]
                else:
                    break
            else:
                # Mask invalid tokens to -inf and sample from the remainder.
                mask = torch.full_like(logits, float('-inf'))
                mask[valid_tokens] = 0
                masked_logits = logits + mask

                probs = torch.softmax(masked_logits, dim=-1)
                next_token_id = torch.multinomial(probs, 1).item()

            current_input_ids = torch.cat([
                current_input_ids,
                torch.tensor([[next_token_id]], device=self.device)
            ], dim=1)

            new_token = self.tokenizer.decode([next_token_id])
            generated_text += new_token

            # Stop as soon as the output parses and satisfies the schema.
            try:
                parsed = json.loads(generated_text.strip())
                if self.validate_against_schema(parsed, schema):
                    break
            except json.JSONDecodeError:
                continue

        return generated_text.strip()

    def validate_against_schema(self, data: Dict, schema: Dict) -> bool:
        """Validate JSON data against a JSON schema."""
        try:
            jsonschema.validate(data, schema)
            return True
        except jsonschema.ValidationError:
            return False

    def generate_with_beam_search(self, prompt: str, schema: Dict, num_beams: int = 3) -> str:
        """Generate with beam search, then return the first schema-valid beam."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Deterministic beam search; temperature is omitted since do_sample=False.
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=150,
                num_beams=num_beams,
                early_stopping=True,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=num_beams
            )

        # Decode only the newly generated tokens of each beam.
        candidates = []
        for output in outputs:
            generated_text = self.tokenizer.decode(
                output[inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            candidates.append(generated_text.strip())

        # Return the first candidate that parses and validates.
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
                if self.validate_against_schema(parsed, schema):
                    return candidate
            except json.JSONDecodeError:
                continue

        # Fall back to the top beam even if it failed validation.
        return candidates[0] if candidates else ""


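# Alternative integration point (a minimal sketch, not part of the original
# class): transformers' generate() accepts a `prefix_allowed_tokens_fn` hook
# that restricts the candidate tokens at every decoding step, so the same
# prefix check can ride on the built-in beam search instead of the manual
# loop in generate_constrained(). This helper and its name are hypothetical.
def generate_with_prefix_hook(generator: ConstrainedJSONGenerator,
                              prompt: str, max_new_tokens: int = 150) -> str:
    inputs = generator.tokenizer(prompt, return_tensors="pt").to(generator.device)
    prompt_len = inputs["input_ids"].shape[1]

    def allowed_tokens(batch_id: int, input_ids: torch.Tensor) -> List[int]:
        # Re-decode only the generated suffix and reuse the prefix check;
        # the schema argument is unused by get_valid_next_tokens today.
        generated = generator.tokenizer.decode(input_ids[prompt_len:])
        valid = generator.get_valid_next_tokens(generated, schema={})
        # generate() needs at least one allowed token per step, so fall
        # back to the whole vocabulary if the capped scan found nothing.
        return valid or list(range(len(generator.tokenizer)))

    with torch.no_grad():
        output = generator.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=3,
            prefix_allowed_tokens_fn=allowed_tokens,
            pad_token_id=generator.tokenizer.eos_token_id,
        )
    return generator.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

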
def create_json_schema_from_function(function_def: Dict) -> Dict:
    """Create a JSON schema for validating calls to a single function."""
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string",
                "const": function_def["name"]
            },
            "arguments": function_def["parameters"]
        },
        "required": ["name", "arguments"],
        "additionalProperties": False
    }

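# For the get_weather definition used in the test below, the wrapper above
# yields (illustrative, abbreviated):
#   {"type": "object",
#    "properties": {"name": {"type": "string", "const": "get_weather"},
#                   "arguments": {...the function's parameter schema...}},
#    "required": ["name", "arguments"],
#    "additionalProperties": False}
# so {"name": "get_weather", "arguments": {"location": "NYC", "days": 3}}
# validates, while a mismatched name or extra top-level keys are rejected.
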
def test_constrained_generation():
    """Exercise the constrained generator end to end."""
    print("🧪 Testing Constrained JSON Generation...")

    model_name = "HuggingFaceTB/SmolLM3-3B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Keep the model and the generator on the same device.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map=device
    )

    generator = ConstrainedJSONGenerator(model, tokenizer, device=device)

    function_def = {
        "name": "get_weather",
        "description": "Get weather forecast",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "days": {"type": "integer"}
            },
            "required": ["location", "days"]
        }
    }

    schema = create_json_schema_from_function(function_def)

    prompt = f"""<|im_start|>system
You are a helpful assistant that calls functions by responding with valid JSON when given a schema. Always respond with JSON function calls only, never prose.<|im_end|>

<schema>
{json.dumps(function_def, indent=2)}
</schema>

<|im_start|>user
Get 3-day weather for New York<|im_end|>
<|im_start|>assistant
"""

    print("🎯 Testing constrained generation...")
    result = generator.generate_constrained(prompt, schema)
    print(f"🤖 Constrained result: {result}")

    try:
        parsed = json.loads(result)
        if generator.validate_against_schema(parsed, schema):
            print("✅ Valid JSON with correct schema!")
        else:
            print("❌ JSON parsed but does not match the schema")
    except Exception as e:
        print(f"❌ Validation failed: {e}")

    print("🎯 Testing beam search...")
    beam_result = generator.generate_with_beam_search(prompt, schema)
    print(f"🤖 Beam result: {beam_result}")

    try:
        parsed = json.loads(beam_result)
        if generator.validate_against_schema(parsed, schema):
            print("✅ Beam search produced valid JSON!")
        else:
            print("❌ Beam output parsed but does not match the schema")
    except Exception as e:
        print(f"❌ Beam validation failed: {e}")

if __name__ == "__main__":
    test_constrained_generation()