| """Search tool for agent.""" | |
| from langchain_core.tools import tool | |
| from src.rag.retriever import retriever | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from src.middlewares.logging import get_logger | |
# Module-level logger for this tool, created via the project's get_logger
# helper under the name "search_tool"; search_documents uses it to record
# failed searches.
logger = get_logger("search_tool")
async def search_documents(
    query: str,
    user_id: str,
    db: AsyncSession,
    num_results: int = 5,
) -> str:
    """Search the user's uploaded documents for relevant information.

    Delegates retrieval to ``retriever.retrieve`` and formats each hit as a
    ``[Source: ...]`` header line followed by the hit's content.

    Args:
        query: The search query or question.
        user_id: The ID of the user whose documents are searched.
        db: Database session, passed through to the retriever.
        num_results: Number of results to request from the retriever
            (default: 5).

    Returns:
        The formatted excerpts, separated by blank lines. Each excerpt is
        labelled with its filename ("Unknown" if absent) and, when present,
        its page label. If the retriever returns no results, a fixed
        "no relevant information" message is returned instead. If retrieval
        or formatting raises an ``Exception``, the error is logged and a
        fixed apology message is returned rather than propagating.
    """
    try:
        results = await retriever.retrieve(query, user_id, db, num_results)
        if not results:
            return "No relevant information found in the documents."

        formatted_results = []
        for result in results:
            filename = result.metadata.get("filename", "Unknown")
            page = result.metadata.get("page_label")
            # Drop the page suffix when the page label is missing or falsy.
            source_label = f"{filename}, p.{page}" if page else filename
            formatted_results.append(f"[Source: {source_label}]\n{result.content}\n")
        return "\n".join(formatted_results)
    except Exception as e:
        # Tool boundary: hand the agent a friendly message instead of raising.
        # Record the exception type as well, because str(e) is empty for many
        # exceptions (e.g. a bare TimeoutError()), which would leave the log
        # entry with no usable diagnostic.
        logger.error("Search failed", error=str(e), error_type=type(e).__name__)
        return "Sorry, I encountered an error while searching the documents."