Build error
Update app.py
app.py CHANGED
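Review note on the diff below: the new SlicedLLM subclasses LangChain's LLM base class, which is a pydantic model. Assigning undeclared attributes in a hand-written __init__ (self.inner_llm = ..., self.max_chars = ...) typically fails at construction time with an "object has no field" error, which may be what is behind the Build error status above. The conventional shape for a custom LLM declares its state as class-level fields and overrides only _call. A minimal sketch under that assumption (SlicedLLMSketch and its body are illustrative, not the committed code):

from typing import Any, List, Optional

from langchain_core.language_models.llms import LLM

class SlicedLLMSketch(LLM):
    # Pydantic fields declared on the class, not assigned in a custom __init__.
    inner_llm: Any
    max_chars: int = 2048

    @property
    def _llm_type(self) -> str:
        return "sliced-llm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Delegate to the wrapped LLM, then keep only the tail so the ReAct
        # parser sees the most recent Thought / Action / Final Answer.
        text = self.inner_llm.invoke(prompt, stop=stop)
        return text[-self.max_chars :].strip()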
@@ -58,78 +58,158 @@ def safe_calculator_func(expression: str) -> str:
         return f"Error calculating '{expression}': Invalid expression or calculation error ({e})."
 
 class SlicedLLM(LLM):
-
-
-
 
-
-
-
 
+    """
+    Light wrapper around any LangChain LLM (we'll use the HuggingFacePipeline wrapper).
+    Responsibilities:
+      - Call the inner LLM
+      - Extract text robustly from different return shapes
+      - Truncate to `max_chars` from the end (keeps the most recent reasoning)
+      - Strip instruction echoing by keeping from the last 'Thought:' if present
+    """
+
+    def __init__(self, inner_llm, max_chars: int = 2048, **kwargs):
+        self.inner_llm = inner_llm
+        self.max_chars = int(max_chars)
+        # required for LangChain LLM subclasses
+        self.max_retries = kwargs.get("max_retries", 1)
+
+    @property
+    def _llm_type(self) -> str:
+        return "sliced-llm"
+
+    def _call(self, prompt: str, stop=None) -> str:
+        """
+        Core call entrypoint used by LangChain. We call the inner LLM and then post-process.
+        """
+        # 1) Call inner LLM (it may expose _call or be callable)
+        raw = None
+        # inner may be a LangChain LLM (with _call) or a callable pipeline
+        if hasattr(self.inner_llm, "_call"):
+            raw = self.inner_llm._call(prompt, stop=stop)
+        else:
+            # fallback - call and try to extract text
+            # Many pipeline wrappers accept a string and return text or list
+            raw = self.inner_llm(prompt)
+
+        # 2) Extract text from common return shapes
+        text = self._extract_text(raw)
+
+        # 3) Attempt to remove repeated instruction blocks by finding last 'Thought:' anchor
+        # We keep text from the last "Thought:" onward if that appears in the output.
+        # This removes prompt-echoed instruction blocks that often appear earlier in the string.
+        last_thought_idx = text.rfind("\nThought:")
+        if last_thought_idx >= 0:
+            # keep from the last Thought: (include the marker so the parser sees it)
+            text = text[last_thought_idx + 1 :]  # +1 drops the leading newline, keeps the marker
 
+        # 4) Truncate to keep the most recent reasoning / final answer
+        if len(text) > self.max_chars:
+            text = text[-self.max_chars :]
 
+        # 5) Strip leading/trailing whitespace
+        return text.strip()
+
+    def _extract_text(self, raw):
+        """
+        Handle possible return formats:
+          - plain str
+          - list/dict results from HF pipeline
+          - objects exposing .content or ['generated_text']
+        """
+        # Direct string
+        if isinstance(raw, str):
+            return raw
+
+        # If it's a list (transformers pipeline may return a list of dicts)
+        if isinstance(raw, (list, tuple)) and len(raw) > 0:
+            first = raw[0]
+            if isinstance(first, dict):
+                # common keys: 'generated_text', 'text'
+                for k in ("generated_text", "text", "output_text"):
+                    if k in first:
+                        return str(first[k])
+                # else stringify the dict
+                return str(first)
+            else:
+                return str(first)
+
+        # If it's a dict with 'generated_text' etc.
+        if isinstance(raw, dict):
+            for k in ("generated_text", "text", "output_text"):
+                if k in raw:
+                    return str(raw[k])
+            # fallback to string repr
+            return str(raw)
+
+        # Last resort: string conversion
+        return str(raw)
+
+    def _identifying_params(self):
+        return {"inner": getattr(self.inner_llm, "_llm_type", None), "max_chars": self.max_chars}
+
+# --- Completely rewritten LangChainAgentWrapper (drop-in) ---
 class LangChainAgentWrapper:
     """
-
-
-
-
-
-
+    Rewritten, robust LangChain agent wrapper:
+      - loads the Gemma model (model_id argument)
+      - wraps the HF pipeline into HuggingFacePipeline (LangChain)
+      - wraps that into SlicedLLM to truncate / clean model outputs
+      - builds the ReAct prompt (contains {tools} and {tool_names})
+      - creates the agent with create_react_agent + AgentExecutor
     """
 
-    def __init__(
+    def __init__(
+        self,
+        model_id: str = "google/gemma-2b-it",
+        max_new_tokens: int = 96,
+        max_chars: int = 2048,
+        max_iterations: int = 2,
+    ):
         print("Initializing LangChainAgentWrapper...")
-
-        # NOTE: Pick the exact model id you intend to load here.
-        # Keep model_id variable consistent so you don't accidentally load a different model.
-        model_id = "google/gemma-2b-it"
-
         try:
-            #
+            # Lazy/delayed imports
             from langchain.agents import AgentExecutor, create_react_agent
             from langchain_community.tools import DuckDuckGoSearchRun
 
-            # --- Tokenizer & Model
+            # --- Tokenizer & Model ---
             print(f"Loading tokenizer for: {model_id}")
             tokenizer = AutoTokenizer.from_pretrained(model_id)
 
             print(f"Loading model: {model_id}")
-            # Use AutoModelForCausalLM for Gemma; bfloat16 for Ada / L40S style cards
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 torch_dtype=torch.bfloat16,
                 device_map="auto",
-                offload_folder="offload",
+                offload_folder="offload",
             )
             print("Model loaded successfully.")
             print(f"Allocated: {memory_allocated()/1e9:.2f} GB | Reserved: {memory_reserved()/1e9:.2f} GB")
 
-            # ---
-            # return_full_text=False avoids echoing the whole prompt
-            # We set a conservative max_new_tokens suitable for ReAct loops and small models
+            # --- HF pipeline (transformers) with safe defaults ---
             llm_pipeline = transformers.pipeline(
                 "text-generation",
                 model=model,
                 tokenizer=tokenizer,
-                max_new_tokens=
-                return_full_text=False,
+                max_new_tokens=max_new_tokens,
+                return_full_text=False,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
             )
-            print("
+            print("Transformers pipeline created successfully.")
 
-            # (langchain_huggingface.HuggingFacePipeline expects a transformers pipeline)
-            self.llm = HuggingFacePipeline(pipeline=llm_pipeline)
-            print("HuggingFacePipeline wrapper created.")
+            # --- Wrap pipeline into LangChain HuggingFacePipeline LLM ---
+            base_lc_llm = HuggingFacePipeline(pipeline=llm_pipeline)
+            # --- Wrap that LLM into our slicer to keep outputs trimmed and to strip instruction echoes ---
+            self.llm = SlicedLLM(base_lc_llm, max_chars=max_chars)
 
+            print("SlicedLLM wrapper created successfully.")
 
-            # ---
+            # --- Tools ---
             print("Defining tools...")
             search_tool = DuckDuckGoSearchRun(
                 name="web_search",
-                description="
+                description="Web search via DuckDuckGo for up-to-date facts/events."
             )
 
-            # Ensure these Tool.name values exactly match the strings in the prompt's action list
             self.tools = [
                 Tool(
                     name="get_current_time_in_timezone",
@@ -143,57 +223,46 @@ class LangChainAgentWrapper:
                     description=safe_calculator_func.__doc__
                 ),
             ]
-            print(f"Tools prepared
+            print(f"Tools prepared: {[t.name for t in self.tools]}")
 
-            # --- ReAct
-            # Important: keep this prompt short and *do not* encourage repetition.
-            # We include Observation because LangChain inserts the tool result back into the scratchpad.
+            # --- ReAct prompt (must contain {tools} and {tool_names}) ---
             react_prompt = PromptTemplate(
                 input_variables=["tools", "tool_names", "agent_scratchpad"],
                 template="""
-
-
-
-
-
-
-
-
-
-
-
+DO NOT REPEAT OR PARAPHRASE ANY PART OF THIS PROMPT.
+
+You are an assistant that strictly follows the ReAct format.
+
+You can use these tools:
+{tools}
+
+Valid tool names: {tool_names}
+
+When responding, follow this exact grammar and include nothing else:
 
-Thought: <reasoning>
-Action: <one of {tool_names} OR "none">
-Action Input: <JSON input for the tool>
+Thought: <brief reasoning>
+Action: <one of {tool_names} OR "none">
+Action Input: <input for the action>
 
-(If Action is not "none", the system will provide an Observation.)
-(After the Observation, you continue with another Thought/Action loop.)
-
-If you choose Action: none, you MUST end with:
-Final Answer: <your final answer>
-
-Begin your reasoning now.
-
-{agent_scratchpad}
-Thought:
-"""
-            )
+(If you choose an action other than "none", the system will insert an Observation before you continue.)
+If Action is "none", finish by outputting:
+Final Answer: <short direct answer>
+
+{agent_scratchpad}
+Thought:
+""",
+            )
+
+            # --- Create agent + executor ---
             print("Creating agent...")
             agent = create_react_agent(self.llm, self.tools, react_prompt)
 
-            # NOTE: We intentionally do NOT add a ConversationBufferWindowMemory here.
-            # ReAct agents benefit from the explicit scratchpad pattern; adding memory can
-            # sometimes re-introduce context and cause token growth. If you really want memory,
-            # prefer a tiny window and test thoroughly.
             self.agent_executor = AgentExecutor(
                 agent=agent,
                 tools=self.tools,
                 verbose=True,
                 handle_parsing_errors=True,
-                max_iterations=
+                max_iterations=max_iterations,
             )
             print("LangChain agent created successfully.")
 
@@ -202,6 +271,22 @@ class LangChainAgentWrapper:
             traceback.print_exc()
             raise RuntimeError(f"LangChain agent initialization failed: {e}") from e
 
+    def __call__(self, question: str) -> str:
+        """
+        Run the agent on a single question.
+        We rely on AgentExecutor to manage the ReAct loops.
+        """
+        print(f"\n--- LangChainAgentWrapper received question: {question[:140]}... ---")
+        try:
+            # AgentExecutor expects {"input": question}
+            response = self.agent_executor.invoke({"input": question})
+            return response.get("output", "No output found.")
+        except Exception as e:
+            print(f"ERROR: LangChain agent execution failed: {e}")
+            traceback.print_exc()
+            # Return an informative string so the outer code can still submit something
+            return f"Agent Error: Failed to process the question. Details: {e}"
+
     def __call__(self, question: str) -> str:
         """
         Run the agent on a single question. We rely on the AgentExecutor to manage