DragonProgrammer committed on
Commit
46292e7
·
verified ·
1 Parent(s): ce81d3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -72
app.py CHANGED
@@ -58,78 +58,158 @@ def safe_calculator_func(expression: str) -> str:
58
  return f"Error calculating '{expression}': Invalid expression or calculation error ({e})."
59
 
60
  class SlicedLLM(LLM):
61
- def __init__(self, base_llm, max_chars=2048):
62
- self.base_llm = base_llm
63
- self.max_chars = max_chars
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- def _call(self, prompt, **kwargs):
66
- out = self.base_llm._call(prompt, **kwargs)
67
- return out[-self.max_chars:]
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class LangChainAgentWrapper:
70
  """
71
- Clean, corrected, and hardened LangChain agent wrapper for Gemma-2b-it.
72
- - uses model_id consistently
73
- - safe generation defaults (bfloat16, sensible max_new_tokens)
74
- - compact ReAct prompt including Observation
75
- - no pre-emptive llm.invoke() in __init__
76
- - no ConversationBufferWindowMemory to avoid contaminating ReAct scratchpad
77
  """
78
 
79
- def __init__(self):
 
 
 
 
 
 
80
  print("Initializing LangChainAgentWrapper...")
81
-
82
- # NOTE: Pick the exact model id you intend to load here.
83
- # Keep model_id variable consistent so you don't accidentally load a different model.
84
- model_id = "google/gemma-2b-it"
85
-
86
  try:
87
- # --- Delayed Imports (keeps startup robust) ---
88
  from langchain.agents import AgentExecutor, create_react_agent
89
  from langchain_community.tools import DuckDuckGoSearchRun
90
 
91
- # --- Tokenizer & Model load ---
92
  print(f"Loading tokenizer for: {model_id}")
93
  tokenizer = AutoTokenizer.from_pretrained(model_id)
94
 
95
  print(f"Loading model: {model_id}")
96
- # Use AutoModelForCausalLM for Gemma; bfloat16 for Ada / L40S style cards
97
  model = AutoModelForCausalLM.from_pretrained(
98
  model_id,
99
  torch_dtype=torch.bfloat16,
100
  device_map="auto",
101
- offload_folder="offload", # will offload to disk when needed
102
  )
103
  print("Model loaded successfully.")
104
  print(f"Allocated: {memory_allocated()/1e9:.2f} GB | Reserved: {memory_reserved()/1e9:.2f} GB")
105
 
106
- # --- HuggingFace pipeline with safe defaults ---
107
- # return_full_text=False avoids echoing the whole prompt
108
- # We set a conservative max_new_tokens suitable for ReAct loops and small models
109
  llm_pipeline = transformers.pipeline(
110
  "text-generation",
111
  model=model,
112
  tokenizer=tokenizer,
113
- max_new_tokens=96, # safe default
114
- return_full_text=False, # return only the generated part
115
  pad_token_id=tokenizer.eos_token_id,
116
  eos_token_id=tokenizer.eos_token_id,
117
  )
118
- print("Model pipeline created successfully.")
 
 
 
 
 
119
 
120
- # Wrap the pipeline into the LangChain HuggingFacePipeline LLM wrapper
121
- # (langchain_huggingface.HuggingFacePipeline expects a transformers pipeline)
122
- self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
123
- print("HuggingFacePipeline wrapper created.")
124
 
125
- # --- Initialize Tools ---
126
  print("Defining tools...")
127
  search_tool = DuckDuckGoSearchRun(
128
  name="web_search",
129
- description="A tool that performs a web search using DuckDuckGo. Use this to find up-to-date information about events, facts, or topics when the answer isn't already known."
130
  )
131
 
132
- # Ensure these Tool.name values exactly match the strings in the prompt's action list
133
  self.tools = [
134
  Tool(
135
  name="get_current_time_in_timezone",
@@ -143,57 +223,46 @@ class LangChainAgentWrapper:
143
  description=safe_calculator_func.__doc__
144
  ),
145
  ]
146
- print(f"Tools prepared for agent: {[tool.name for tool in self.tools]}")
147
 
148
- # --- ReAct Prompt (compact + includes Observation) ---
149
- # Important: keep this prompt short and *do not* encourage repetition.
150
- # We include Observation because LangChain inserts the tool result back into the scratchpad.
151
  react_prompt = PromptTemplate(
152
  input_variables=["tools", "tool_names", "agent_scratchpad"],
153
  template="""
154
- DO NOT REPEAT OR PARAPHRASE ANY PART OF THIS PROMPT.
155
-
156
- You are an assistant that strictly follows the ReAct format.
157
-
158
- You can use these tools:
159
- {tools}
160
-
161
- Valid tool names: {tool_names}
162
-
163
- When responding, you MUST follow **this exact output grammar** and NEVER include anything else:
164
-
165
- Thought: <reasoning>
166
- Action: <one of {tool_names} OR "none">
167
- Action Input: <JSON input for the tool>
168
-
169
- (If Action is not "none", the system will provide an Observation.)
170
- (After the Observation, you continue with another Thought/Action loop.)
171
-
172
- If you choose Action: none, you MUST end with:
173
- Final Answer: <your final answer>
174
-
175
- Begin your reasoning now.
176
-
177
- {agent_scratchpad}
178
- Thought:
179
- """
180
- )
181
 
 
 
 
182
 
183
- # --- Create agent & executor ---
 
 
 
 
 
 
 
 
 
184
  print("Creating agent...")
185
  agent = create_react_agent(self.llm, self.tools, react_prompt)
186
 
187
- # NOTE: We intentionally do NOT add a ConversationBufferWindowMemory here.
188
- # ReAct agents benefit from the explicit scratchpad pattern; adding memory can
189
- # sometimes re-introduce context and cause token growth. If you really want memory,
190
- # prefer a tiny window and test thoroughly.
191
  self.agent_executor = AgentExecutor(
192
  agent=agent,
193
  tools=self.tools,
194
  verbose=True,
195
  handle_parsing_errors=True,
196
- max_iterations=2, # you tuned this earlier; 1-2 is best for Gemma-2B
197
  )
198
  print("LangChain agent created successfully.")
199
 
@@ -202,6 +271,22 @@ class LangChainAgentWrapper:
202
  traceback.print_exc()
203
  raise RuntimeError(f"LangChain agent initialization failed: {e}") from e
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def __call__(self, question: str) -> str:
206
  """
207
  Run the agent on a single question. We rely on the AgentExecutor to manage
 
58
  return f"Error calculating '{expression}': Invalid expression or calculation error ({e})."
59
 
60
  class SlicedLLM(LLM):
61
+ """
62
+ Light wrapper around any LangChain LLM (we'll use the HuggingFacePipeline wrapper).
63
+ Responsibilities:
64
+ - Call the inner LLM
65
+ - Extract text robustly from different return shapes
66
+ - Truncate to `max_chars` from the end (keeps the most recent reasoning)
67
+ - Strip instruction echoing by keeping from the last 'Thought:' if present
68
+ """
69
+ def __init__(self, inner_llm, max_chars: int = 2048, **kwargs):
70
+ self.inner_llm = inner_llm
71
+ self.max_chars = int(max_chars)
72
+ # required for LangChain LLM subclasses
73
+ self.max_retries = kwargs.get("max_retries", 1)
74
+
75
+ @property
76
+ def _llm_type(self) -> str:
77
+ return "sliced-llm"
78
+
79
+ def _call(self, prompt: str, stop=None) -> str:
80
+ """
81
+ Core call entrypoint used by LangChain. We call the inner LLM and then post-process.
82
+ """
83
+ # 1) Call inner LLM (it may expose _call or be callable)
84
+ raw = None
85
+ # inner may be a LangChain LLM (with _call) or a callable pipeline
86
+ if hasattr(self.inner_llm, "_call"):
87
+ raw = self.inner_llm._call(prompt, stop=stop)
88
+ else:
89
+ # fallback - call and try to extract text
90
+ # Many pipeline wrappers accept a string and return text or list
91
+ raw = self.inner_llm(prompt)
92
+
93
+ # 2) Extract text from common return shapes
94
+ text = self._extract_text(raw)
95
+
96
+ # 3) Attempt to remove repeated instruction blocks by finding last 'Thought:' anchor
97
+ # We keep text from the last "Thought:" onward if that appears in the output.
98
+ # This removes prompt-echoed instruction blocks that often appear earlier in the string.
99
+ last_thought_idx = text.rfind("\nThought:")
100
+ if last_thought_idx >= 0:
101
+ # keep from the last Thought: (include the marker so parser sees it)
102
+ text = text[last_thought_idx + 1 :] # +1 to keep leading newline trimmed
103
 
104
+ # 4) Truncate to keep the most recent reasoning / final answer
105
+ if len(text) > self.max_chars:
106
+ text = text[-self.max_chars :]
107
 
108
+ # 5) Strip leading/trailing whitespace
109
+ return text.strip()
110
+
111
+ def _extract_text(self, raw):
112
+ """
113
+ Handle possible return formats:
114
+ - plain str
115
+ - list/dict results from HF pipeline
116
+ - objects exposing .content or ['generated_text']
117
+ """
118
+ # Direct string
119
+ if isinstance(raw, str):
120
+ return raw
121
+
122
+ # If it's a list (transformers pipeline may return list of dicts)
123
+ if isinstance(raw, (list, tuple)) and len(raw) > 0:
124
+ first = raw[0]
125
+ if isinstance(first, dict):
126
+ # common keys: 'generated_text', 'text'
127
+ for k in ("generated_text", "text", "output_text"):
128
+ if k in first:
129
+ return str(first[k])
130
+ # else stringify the dict
131
+ return str(first)
132
+ else:
133
+ return str(first)
134
+
135
+ # If it's a dict with 'generated_text' etc.
136
+ if isinstance(raw, dict):
137
+ for k in ("generated_text", "text", "output_text"):
138
+ if k in raw:
139
+ return str(raw[k])
140
+ # fallback to string repr
141
+ return str(raw)
142
+
143
+ # Last resort: string conversion
144
+ return str(raw)
145
+
146
+ def _identifying_params(self):
147
+ return {"inner": getattr(self.inner_llm, "_llm_type", None), "max_chars": self.max_chars}
148
+
149
+ # --- Completely rewritten LangChainAgentWrapper (drop-in) ---
150
  class LangChainAgentWrapper:
151
  """
152
+ Rewritten, robust LangChain agent wrapper:
153
+ - loads Gemma model (model_id variable)
154
+ - wraps HF pipeline into HuggingFacePipeline (LangChain)
155
+ - wraps that into SlicedLLM to truncate / clean model outputs
156
+ - builds ReAct prompt (contains {tools} and {tool_names})
157
+ - creates agent with create_react_agent + AgentExecutor
158
  """
159
 
160
+ def __init__(
161
+ self,
162
+ model_id: str = "google/gemma-2b-it",
163
+ max_new_tokens: int = 96,
164
+ max_chars: int = 2048,
165
+ max_iterations: int = 2,
166
+ ):
167
  print("Initializing LangChainAgentWrapper...")
 
 
 
 
 
168
  try:
169
+ # Lazy/delayed imports
170
  from langchain.agents import AgentExecutor, create_react_agent
171
  from langchain_community.tools import DuckDuckGoSearchRun
172
 
173
+ # --- Tokenizer & Model ---
174
  print(f"Loading tokenizer for: {model_id}")
175
  tokenizer = AutoTokenizer.from_pretrained(model_id)
176
 
177
  print(f"Loading model: {model_id}")
 
178
  model = AutoModelForCausalLM.from_pretrained(
179
  model_id,
180
  torch_dtype=torch.bfloat16,
181
  device_map="auto",
182
+ offload_folder="offload",
183
  )
184
  print("Model loaded successfully.")
185
  print(f"Allocated: {memory_allocated()/1e9:.2f} GB | Reserved: {memory_reserved()/1e9:.2f} GB")
186
 
187
+ # --- HF pipeline (transformers) with safe defaults ---
 
 
188
  llm_pipeline = transformers.pipeline(
189
  "text-generation",
190
  model=model,
191
  tokenizer=tokenizer,
192
+ max_new_tokens=max_new_tokens,
193
+ return_full_text=False,
194
  pad_token_id=tokenizer.eos_token_id,
195
  eos_token_id=tokenizer.eos_token_id,
196
  )
197
+ print("Transformers pipeline created successfully.")
198
+
199
+ # --- Wrap pipeline into LangChain HuggingFacePipeline LLM ---
200
+ base_lc_llm = HuggingFacePipeline(pipeline=llm_pipeline)
201
+ # --- Wrap that LLM into our slicer to keep outputs trimmed and to strip instruction echoes ---
202
+ self.llm = SlicedLLM(base_lc_llm, max_chars=max_chars)
203
 
204
+ print("SlicedLLM wrapper created successfully.")
 
 
 
205
 
206
+ # --- Tools ---
207
  print("Defining tools...")
208
  search_tool = DuckDuckGoSearchRun(
209
  name="web_search",
210
+ description="Web search via DuckDuckGo for up-to-date facts/events."
211
  )
212
 
 
213
  self.tools = [
214
  Tool(
215
  name="get_current_time_in_timezone",
 
223
  description=safe_calculator_func.__doc__
224
  ),
225
  ]
226
+ print(f"Tools prepared: {[t.name for t in self.tools]}")
227
 
228
+ # --- ReAct prompt (must contain {tools} and {tool_names}) ---
 
 
229
  react_prompt = PromptTemplate(
230
  input_variables=["tools", "tool_names", "agent_scratchpad"],
231
  template="""
232
+ DO NOT REPEAT OR PARAPHRASE ANY PART OF THIS PROMPT.
233
+
234
+ You are an assistant that strictly follows the ReAct format.
235
+
236
+ You can use these tools:
237
+ {tools}
238
+
239
+ Valid tool names: {tool_names}
240
+
241
+ When responding, follow this exact grammar and include nothing else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ Thought: <brief reasoning>
244
+ Action: <one of {tool_names} OR "none">
245
+ Action Input: <input for the action>
246
 
247
+ (If you choose an action other than "none", the system will insert an Observation before you continue.)
248
+ If Action is "none", finish by outputting:
249
+ Final Answer: <short direct answer>
250
+
251
+ {agent_scratchpad}
252
+ Thought:
253
+ """,
254
+ )
255
+
256
+ # --- Create agent + executor ---
257
  print("Creating agent...")
258
  agent = create_react_agent(self.llm, self.tools, react_prompt)
259
 
 
 
 
 
260
  self.agent_executor = AgentExecutor(
261
  agent=agent,
262
  tools=self.tools,
263
  verbose=True,
264
  handle_parsing_errors=True,
265
+ max_iterations=max_iterations,
266
  )
267
  print("LangChain agent created successfully.")
268
 
 
271
  traceback.print_exc()
272
  raise RuntimeError(f"LangChain agent initialization failed: {e}") from e
273
 
274
+ def __call__(self, question: str) -> str:
275
+ """
276
+ Run the agent on a single question.
277
+ We rely on AgentExecutor to manage the ReAct loops.
278
+ """
279
+ print(f"\n--- LangChainAgentWrapper received question: {question[:140]}... ---")
280
+ try:
281
+ # AgentExecutor expects {"input": question}
282
+ response = self.agent_executor.invoke({"input": question})
283
+ return response.get("output", "No output found.")
284
+ except Exception as e:
285
+ print(f"ERROR: LangChain agent execution failed: {e}")
286
+ traceback.print_exc()
287
+ # Return an informative string so the outer code can still submit something
288
+ return f"Agent Error: Failed to process the question. Details: {e}"
289
+
290
  def __call__(self, question: str) -> str:
291
  """
292
  Run the agent on a single question. We rely on the AgentExecutor to manage