File size: 5,983 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright 2025 The Google AI Edge Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/components/tool_use/parser_utils.h"

#include <optional>
#include <string>
#include <vector>

#include "absl/base/no_destructor.h"  // from @com_google_absl
#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/strings/str_cat.h"  // from @com_google_absl
#include "absl/strings/str_join.h"  // from @com_google_absl
#include "absl/strings/str_split.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "nlohmann/json.hpp"  // from @nlohmann_json
#include "runtime/components/tool_use/fc_parser_utils.h"
#include "runtime/components/tool_use/json_parser_utils.h"
#include "runtime/components/tool_use/python_parser_utils.h"
#include "re2/re2.h"  // from @com_googlesource_code_re2

namespace litert::lm {

namespace {

// Builds the regex that splits a response into free text followed by a fenced
// code block. The resulting pattern has the shape
//   (?ms)(.*?)<start>(.*?)<end>
// where capture group 1 is the lazily matched text before the opening fence
// and capture group 2 is the lazily matched code between the two fences.
//
// When `escape_fence_strings` is true, both fences are run through
// RE2::QuoteMeta so that they are matched literally; otherwise they are
// spliced into the pattern verbatim and are interpreted as regex syntax.
//
// The returned RE2 is not validated here; callers are expected to check ok().
RE2 TextAndToolCodeRegex(absl::string_view code_fence_start,
                         absl::string_view code_fence_end,
                         bool escape_fence_strings) {
  // Resolve each fence to the exact fragment that goes into the pattern.
  const std::string opening_fence = escape_fence_strings
                                        ? RE2::QuoteMeta(code_fence_start)
                                        : std::string(code_fence_start);
  const std::string closing_fence = escape_fence_strings
                                        ? RE2::QuoteMeta(code_fence_end)
                                        : std::string(code_fence_end);
  // Returned as a prvalue: RE2 objects cannot be copied or moved.
  return RE2(
      absl::StrCat("(?ms)(.*?)", opening_fence, "(.*?)", closing_fence));
}

// Rewrites `input` one line at a time. Each '\n'-separated line on which
// `regex` partially matches is replaced by the regex's first capture group;
// every other line is kept unchanged. The lines are then rejoined with '\n'.
std::string FilterLines(absl::string_view input, const RE2& regex) {
  std::vector<std::string> rewritten;
  std::string first_group;

  for (absl::string_view current_line : absl::StrSplit(input, '\n')) {
    const bool matched =
        RE2::PartialMatch(current_line, regex, &first_group);
    rewritten.push_back(matched ? first_group : std::string(current_line));
  }

  return absl::StrJoin(rewritten, "\n");
}

}  // namespace

// Translates a syntax name into its SyntaxType. The recognized names are
// "python", "json" and "fc" (exact match); any other string yields
// SyntaxType::kUnknown.
SyntaxType GetSyntaxType(absl::string_view syntax_type) {
  using SyntaxTable = absl::flat_hash_map<absl::string_view, SyntaxType>;
  // Built once on first use and intentionally never destroyed.
  static const absl::NoDestructor<SyntaxTable> kSyntaxByName({
      {"python", SyntaxType::kPython},
      {"json", SyntaxType::kJson},
      {"fc", SyntaxType::kFc},
  });
  const auto match = kSyntaxByName->find(syntax_type);
  return match != kSyntaxByName->end() ? match->second : SyntaxType::kUnknown;
}

// Splits a model response into plain text segments and tool calls.
//
// The response is scanned for code blocks delimited by `code_fence_start` and
// `code_fence_end`. Text found before each block (and any text left after the
// last block) is appended to the "content" array as
// {"type": "text", "text": ...}. Each non-empty code block is parsed according
// to `syntax_type`, and every parsed call is appended to the "tool_calls"
// array as {"type": "function", "function": ...}.
//
// Args:
//   response_str: The raw response to parse. An empty response produces a
//     "content" array holding a single empty text element.
//   code_fence_start, code_fence_end: Delimiters of a code block.
//   syntax_type: Selects the parser (kPython, kJson or kFc) for code blocks.
//   escape_fence_strings: If true the fences are matched literally; otherwise
//     they are treated as regex fragments.
//   tool_code_regex: Optional. When non-empty, it is applied to every line of
//     a code block before parsing; a matching line is replaced by the regex's
//     first capture group.
//
// Returns the assembled JSON object, or InvalidArgumentError when the fence
// regex or `tool_code_regex` is invalid, when `syntax_type` is unsupported, or
// when a code block fails to parse.
absl::StatusOr<nlohmann::ordered_json> ParseTextAndToolCalls(
    absl::string_view response_str, absl::string_view code_fence_start,
    absl::string_view code_fence_end, SyntaxType syntax_type,
    bool escape_fence_strings, absl::string_view tool_code_regex) {
  nlohmann::ordered_json result = nlohmann::ordered_json::object();
  // If the response is empty, return a content array with a single empty text
  // element to ensure the output format is consistent.
  if (response_str.empty()) {
    result["content"].push_back({{"type", "text"}, {"text", ""}});
    return result;
  }
  const RE2 fence_regex = TextAndToolCodeRegex(code_fence_start, code_fence_end,
                                               escape_fence_strings);
  if (!fence_regex.ok()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Invalid regex: ", fence_regex.pattern(),
                     " error: ", fence_regex.error()));
  }

  std::string text;
  std::string code_block;
  // Compiled lazily, at most once, the first time a code block is found. This
  // keeps the original behavior of only validating `tool_code_regex` when a
  // code block actually exists, without recompiling it for every block.
  std::optional<RE2> line_filter;
  const absl::string_view original_response_str = response_str;
  while (true) {
    const auto remaining_before = response_str.size();
    if (!RE2::Consume(&response_str, fence_regex, &text, &code_block)) {
      break;
    }
    // A zero-width match (possible when the fences are empty or are unescaped
    // patterns that can match the empty string) consumes no input, so looping
    // again would never terminate. Nothing was captured in that case; stop
    // and let the remaining input be emitted as plain text below.
    if (response_str.size() == remaining_before) {
      break;
    }

    // Append text to the content array.
    if (!text.empty()) {
      result["content"].push_back({{"type", "text"}, {"text", text}});
    }

    // Before parsing the code block, apply tool_code_regex to each line.
    if (!tool_code_regex.empty()) {
      if (!line_filter.has_value()) {
        line_filter.emplace(tool_code_regex);
      }
      if (!line_filter->ok()) {
        return absl::InvalidArgumentError(
            absl::StrCat("Invalid tool_code_regex: ", tool_code_regex));
      }
      code_block = FilterLines(code_block, *line_filter);
    }

    // Parse tool calls from the code block.
    if (!code_block.empty()) {
      absl::StatusOr<nlohmann::ordered_json> tool_calls;
      if (syntax_type == SyntaxType::kPython) {
        tool_calls = ParsePythonExpression(code_block);
      } else if (syntax_type == SyntaxType::kJson) {
        tool_calls = ParseJsonExpression(code_block);
      } else if (syntax_type == SyntaxType::kFc) {
        tool_calls = ParseFcExpression(code_block);
      } else {
        return absl::InvalidArgumentError(absl::StrCat(
            "Unsupported syntax type: ", static_cast<int>(syntax_type)));
      }
      if (!tool_calls.ok()) {
        return absl::InvalidArgumentError(absl::StrCat(
            "Failed to parse tool calls from response: ", original_response_str,
            " code block: ", code_block,
            " with error: ", tool_calls.status().message()));
      }
      for (const auto& tool_call : *tool_calls) {
        result["tool_calls"].push_back(
            {{"type", "function"}, {"function", tool_call}});
      }
    }
    text.clear();
    code_block.clear();
  }

  // Append the remaining text to the content array.
  if (!response_str.empty()) {
    result["content"].push_back({{"type", "text"}, {"text", response_str}});
  }

  return result;
}

}  // namespace litert::lm