// Copyright 2025 The Google AI Edge Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use antlr4rust::common_token_stream::CommonTokenStream; use antlr4rust::error_strategy::BailErrorStrategy; use antlr4rust::tree::{ParseTree, ParseTreeListener}; use antlr4rust::InputStream; use antlr_fc_tool_call_parser::{antlrfclexer, antlrfcparser, antlrfcparserlistener}; use antlrfclexer::AntlrFcLexer; use antlrfcparser::{ AntlrFcParser, AntlrFcParserContextType, AntlrFcParserTreeWalker, ArrayContext, ArrayContextAttrs, FunctionCallContext, FunctionCallContextAttrs, ObjectContext, ObjectContextAttrs, PairContextAttrs, ValueContext, ValueContextAttrs, }; use antlrfcparserlistener::AntlrFcParserListener; use serde_json::{json, Map, Value}; use std::collections::HashSet; fn strip_escape_tokens(text: &str) -> &str { let mut s = text; const ESCAPE: &str = ""; if s.starts_with(ESCAPE) { s = &s[ESCAPE.len()..]; } if s.ends_with(ESCAPE) { s = &s[..(s.len() - ESCAPE.len())]; } s } fn parse_value(value_ctx: &ValueContext) -> Result { if let Some(escaped_string_ctx) = value_ctx.ESCAPED_STRING() { Ok(json!(strip_escape_tokens(&escaped_string_ctx.get_text()))) } else if let Some(number_ctx) = value_ctx.NUMBER() { let text = number_ctx.get_text(); if let Ok(double_val) = text.parse::() { Ok(json!(double_val)) } else { Err(format!("Failed to parse number: {}", text)) } } else if let Some(object_ctx) = value_ctx.object() { parse_object(&object_ctx) } else if let Some(array_ctx) = value_ctx.array() { parse_array(&array_ctx) } else if let Some(boolean_ctx) = value_ctx.BOOLEAN() { Ok(json!(boolean_ctx.get_text() == "true")) } else if let Some(_null_literal_ctx) = value_ctx.NULL_LITERAL() { Ok(Value::Null) } else { Err(format!("Unhandled value type: {}", value_ctx.get_text())) } } fn parse_array(array_ctx: &ArrayContext) -> Result { let mut list = vec![]; for value in array_ctx.value_all() { let parsed_value = parse_value(&value)?; list.push(parsed_value); } Ok(Value::Array(list)) } fn parse_object(object_ctx: &ObjectContext) -> Result { let mut map = Map::new(); let mut seen_keys = HashSet::new(); for pair_ctx in object_ctx.pair_all() { let id_token = pair_ctx.ID().ok_or_else(|| "Invalid pair in object: ID missing".to_string())?; let value_ctx = pair_ctx.value().ok_or_else(|| "Invalid pair in object: Value missing".to_string())?; let key = id_token.get_text(); if key.is_empty() { return Err("Object key is empty".to_string()); } if seen_keys.contains(&key) { // Log duplicate key but don't treat it as an error. eprintln!("Ignoring duplicate key: {}", key); continue; } seen_keys.insert(key.clone()); let parsed_value = parse_value(&value_ctx) .map_err(|e| format!("Error parsing value for key '{}': {}", key, e))?; map.insert(key, parsed_value); } Ok(Value::Object(map)) } struct FcListener { tool_calls: Result, String>, } impl FcListener { fn new() -> Self { FcListener { tool_calls: Ok(Vec::new()) } } fn tool_calls(self) -> Result, String> { self.tool_calls } } impl<'input> ParseTreeListener<'input, AntlrFcParserContextType> for FcListener {} impl<'input> AntlrFcParserListener<'input> for FcListener { fn enter_functionCall(&mut self, ctx: &FunctionCallContext<'input>) { if let Ok(tool_calls) = &mut self.tool_calls { let name = if let Some(id_token) = ctx.ID() { id_token.get_text() } else { "".to_string() }; let mut tool_call_map = Map::new(); tool_call_map.insert("name".to_string(), json!(name)); if let Some(object_ctx) = ctx.object() { match parse_object(&object_ctx) { Ok(args) => { tool_call_map.insert("arguments".to_string(), args); } Err(e) => { self.tool_calls = Err(e); return; } } } else { tool_call_map.insert("arguments".to_string(), json!({})); } tool_calls.push(Value::Object(tool_call_map)); } } } pub fn parse_fc_expression(text: &str) -> Result, String> { if text.len() == 0 { return Ok(Vec::new()); } let lexer = AntlrFcLexer::new(InputStream::new(text)); let mut parser = AntlrFcParser::with_strategy( CommonTokenStream::new(lexer), Box::new(BailErrorStrategy::new()), ); let start = match parser.start() { Ok(start) => start, Err(e) => return Err(e.to_string()), }; match AntlrFcParserTreeWalker::walk(Box::new(FcListener::new()), start.as_ref()) { Ok(listener) => listener.tool_calls(), Err(e) => Err(e.to_string()), } }