import asyncio
import hashlib
import json
import os
import re
import shutil
import threading
from io import BytesIO, StringIO
from typing import List, Tuple

import docx
import faiss
import fitz
import gradio as gr
import numpy as np
import pandas as pd
import requests
from crawl4ai import AsyncWebCrawler
from pptx import Presentation
from sentence_transformers import SentenceTransformer
|
| | |
| | OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") |
| | OPENROUTER_MODEL = "nvidia/nemotron-nano-12b-v2-vl:free" |
| | EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" |
| | CACHE_DIR = "./cache" |
| | SYSTEM_PROMPT = "You are a helpful assistant." |
| | os.makedirs(CACHE_DIR, exist_ok=True) |
| |
|
| | embedder = SentenceTransformer(EMBEDDING_MODEL_NAME) |
| |
|
| | DOCS: List[str] = [] |
| | FILENAMES: List[str] = [] |
| | EMBEDDINGS: np.ndarray = None |
| | FAISS_INDEX = None |
| | CURRENT_CACHE_KEY: str = "" |
| |
|
| |
|
| | |
| | async def clear_cache_every_5min(): |
| | while True: |
| | await asyncio.sleep(300) |
| | try: |
| | if os.path.exists(CACHE_DIR): |
| | shutil.rmtree(CACHE_DIR) |
| | os.makedirs(CACHE_DIR, exist_ok=True) |
| | print("🧹 Cache cleared successfully.") |
| | except Exception as e: |
| | print(f"[Cache cleanup error] {e}") |
| |
|
| | |
| | asyncio.get_event_loop().create_task(clear_cache_every_5min()) |
| |
|
| |
|
| | |
| | def extract_text_from_pdf(file_bytes: bytes) -> str: |
| | try: |
| | doc = fitz.open(stream=file_bytes, filetype="pdf") |
| | return "\n".join(page.get_text() for page in doc) |
| | except Exception as e: |
| | return f"[PDF extraction error] {e}" |
| |
|
| | def extract_text_from_docx(file_bytes: bytes) -> str: |
| | try: |
| | f = BytesIO(file_bytes) |
| | doc = docx.Document(f) |
| | return "\n".join(p.text for p in doc.paragraphs) |
| | except Exception as e: |
| | return f"[DOCX extraction error] {e}" |
| |
|
| | def extract_text_from_txt(file_bytes: bytes) -> str: |
| | try: |
| | return file_bytes.decode("utf-8", errors="ignore") |
| | except Exception as e: |
| | return f"[TXT extraction error] {e}" |
| |
|
| | def extract_text_from_excel(file_bytes: bytes) -> str: |
| | try: |
| | f = BytesIO(file_bytes) |
| | df = pd.read_excel(f, dtype=str) |
| | return "\n".join("\n".join(df[col].fillna("").astype(str).tolist()) for col in df.columns) |
| | except Exception as e: |
| | return f"[EXCEL extraction error] {e}" |
| |
|
| | def extract_text_from_pptx(file_bytes: bytes) -> str: |
| | try: |
| | f = BytesIO(file_bytes) |
| | prs = Presentation(f) |
| | texts = [] |
| | for slide in prs.slides: |
| | for shape in slide.shapes: |
| | if hasattr(shape, "text"): |
| | texts.append(shape.text) |
| | return "\n".join(texts) |
| | except Exception as e: |
| | return f"[PPTX extraction error] {e}" |
| |
|
| | def extract_text_from_csv(file_bytes: bytes) -> str: |
| | try: |
| | f = StringIO(file_bytes.decode("utf-8", errors="ignore")) |
| | df = pd.read_csv(f, dtype=str) |
| | return df.to_string(index=False) |
| | except Exception as e: |
| | return f"[CSV extraction error] {e}" |
| |
|
| | def extract_text_from_file_tuple(file_tuple) -> Tuple[str, bytes]: |
| | try: |
| | if hasattr(file_tuple, "name") and hasattr(file_tuple, "read"): |
| | return os.path.basename(file_tuple.name), file_tuple.read() |
| | except Exception: |
| | pass |
| | if isinstance(file_tuple, tuple) and len(file_tuple) == 2 and isinstance(file_tuple[1], (bytes, bytearray)): |
| | return file_tuple[0], bytes(file_tuple[1]) |
| | if isinstance(file_tuple, str) and os.path.exists(file_tuple): |
| | with open(file_tuple, "rb") as fh: |
| | return os.path.basename(file_tuple), fh.read() |
| | raise ValueError("Unsupported file object passed by Gradio.") |
| |
|
| | def extract_text_by_ext(filename: str, file_bytes: bytes) -> str: |
| | name = filename.lower() |
| | if name.endswith(".pdf"): return extract_text_from_pdf(file_bytes) |
| | if name.endswith(".docx"): return extract_text_from_docx(file_bytes) |
| | if name.endswith(".txt"): return extract_text_from_txt(file_bytes) |
| | if name.endswith((".xlsx", ".xls")): return extract_text_from_excel(file_bytes) |
| | if name.endswith(".pptx"): return extract_text_from_pptx(file_bytes) |
| | if name.endswith(".csv"): return extract_text_from_csv(file_bytes) |
| | return extract_text_from_txt(file_bytes) |
| |
|
| |
|
| | |
| | def make_cache_key_for_files(files: List[Tuple[str, bytes]]) -> str: |
| | h = hashlib.sha256() |
| | for name, b in sorted(files, key=lambda x: x[0]): |
| | h.update(name.encode()) |
| | h.update(str(len(b)).encode()) |
| | h.update(hashlib.sha256(b).digest()) |
| | return h.hexdigest() |
| |
|
| | def cache_save_embeddings(cache_key: str, embeddings: np.ndarray, filenames: List[str]): |
| | np.savez_compressed(os.path.join(CACHE_DIR, f"{cache_key}.npz"), embeddings=embeddings, filenames=np.array(filenames)) |
| |
|
| | def cache_load_embeddings(cache_key: str): |
| | path = os.path.join(CACHE_DIR, f"{cache_key}.npz") |
| | if not os.path.exists(path): return None |
| | try: |
| | arr = np.load(path, allow_pickle=True) |
| | return arr["embeddings"], arr["filenames"].tolist() |
| | except Exception: |
| | return None |
| |
|
| | def build_faiss_index(embeddings: np.ndarray): |
| | global FAISS_INDEX |
| | if embeddings is None or len(embeddings) == 0: |
| | FAISS_INDEX = None |
| | return None |
| | emb = embeddings.astype("float32") |
| | index = faiss.IndexFlatL2(emb.shape[1]) |
| | index.add(emb) |
| | FAISS_INDEX = index |
| | return index |
| |
|
| | def search_top_k(query: str, k: int = 3): |
| | if FAISS_INDEX is None: |
| | return [] |
| | q_emb = embedder.encode([query], convert_to_numpy=True).astype("float32") |
| | D, I = FAISS_INDEX.search(q_emb, k) |
| | return [{"index": int(i), "distance": float(d), "text": DOCS[i], "source": FILENAMES[i]} for d, i in zip(D[0], I[0]) if i >= 0] |
| |
|
| |
|
| | |
| | def openrouter_chat_system_user(user_prompt: str): |
| | """ |
| | Sends user prompt to OpenRouter and expects a plain text response. |
| | """ |
| | if not OPENROUTER_API_KEY: |
| | return "[OpenRouter error] Missing OPENROUTER_API_KEY." |
| |
|
| | url = "https://openrouter.ai/api/v1/chat/completions" |
| | headers = { |
| | "Authorization": f"Bearer {OPENROUTER_API_KEY}", |
| | "Content-Type": "application/json", |
| | } |
| |
|
| | |
| | payload = { |
| | "model": OPENROUTER_MODEL, |
| | "messages": [ |
| | {"role": "system", "content": SYSTEM_PROMPT + " Always respond in plain text. Avoid JSON or markdown formatting."}, |
| | {"role": "user", "content": user_prompt}, |
| | ], |
| | } |
| |
|
| | try: |
| | r = requests.post(url, headers=headers, json=payload, timeout=60) |
| | r.raise_for_status() |
| | obj = r.json() |
| |
|
| | |
| | if "choices" in obj and obj["choices"]: |
| | choice = obj["choices"][0] |
| | if "message" in choice and "content" in choice["message"]: |
| | text = choice["message"]["content"] |
| | |
| | text = text.strip().replace("```", "").replace("json", "") |
| | return text |
| | elif "text" in choice: |
| | return choice["text"].strip() |
| | return "[OpenRouter] Unexpected response format." |
| |
|
| | except Exception as e: |
| | return f"[OpenRouter request error] {e}" |
| |
|
| |
|
| | |
| | async def _crawl_async_get_markdown(url: str): |
| | async with AsyncWebCrawler() as crawler: |
| | result = await crawler.arun(url=url) |
| | if hasattr(result, "success") and result.success is False: |
| | return f"[Crawl4AI error] {getattr(result, 'error_message', '[Unknown error]')}" |
| | md_obj = getattr(result, "markdown", None) |
| | if md_obj: |
| | return getattr(md_obj, "fit_markdown", None) or getattr(md_obj, "raw_markdown", None) or str(md_obj) |
| | return getattr(result, "text", None) or getattr(result, "html", None) or "[Crawl4AI returned no usable fields]" |
| |
|
| | def crawl_url_sync(url: str) -> str: |
| | try: |
| | return asyncio.run(_crawl_async_get_markdown(url)) |
| | except Exception as e: |
| | return f"[Crawl4AI runtime error] {e}" |
| |
|
| |
|
| | |
| | def upload_and_index(files): |
| | global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY |
| | if not files: |
| | return "No files uploaded.", "" |
| | prepared = [(name := extract_text_from_file_tuple(f)[0], extract_text_from_file_tuple(f)[1]) for f in files] |
| | previews = [{"name": n, "size": len(b)} for n, b in prepared] |
| | cache_key = make_cache_key_for_files(prepared) |
| | CURRENT_CACHE_KEY = cache_key |
| | cached = cache_load_embeddings(cache_key) |
| | if cached: |
| | emb, filenames = cached |
| | EMBEDDINGS = np.array(emb) |
| | FILENAMES = filenames |
| | DOCS = [extract_text_by_ext(n, b) for n, b in prepared] |
| | build_faiss_index(EMBEDDINGS) |
| | return f"Loaded embeddings from cache ({len(FILENAMES)} docs).", json.dumps(previews) |
| | DOCS, FILENAMES = zip(*[(extract_text_by_ext(n, b), n) for n, b in prepared]) |
| | EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32") |
| | cache_save_embeddings(cache_key, EMBEDDINGS, FILENAMES) |
| | build_faiss_index(EMBEDDINGS) |
| | return f"Uploaded and indexed {len(DOCS)} documents.", json.dumps(previews) |
| |
|
| | def crawl_and_index(url: str): |
| | global DOCS, FILENAMES, EMBEDDINGS, CURRENT_CACHE_KEY |
| | if not url: |
| | return "No URL provided.", "" |
| | crawled = crawl_url_sync(url) |
| | if crawled.startswith("[Crawl4AI"): |
| | return crawled, "" |
| | key_hash = hashlib.sha256((url + crawled).encode()).hexdigest() |
| | CURRENT_CACHE_KEY = key_hash |
| | cached = cache_load_embeddings(key_hash) |
| | if cached: |
| | emb, filenames = cached |
| | EMBEDDINGS = np.array(emb) |
| | FILENAMES = filenames |
| | DOCS = [crawled] |
| | build_faiss_index(EMBEDDINGS) |
| | return f"Crawled and loaded embeddings from cache for {url}", crawled[:20000] |
| | DOCS, FILENAMES = [crawled], [url] |
| | EMBEDDINGS = embedder.encode(DOCS, convert_to_numpy=True, show_progress_bar=False).astype("float32") |
| | cache_save_embeddings(key_hash, EMBEDDINGS, FILENAMES) |
| | build_faiss_index(EMBEDDINGS) |
| | return f"Crawled and indexed {url}", crawled[:20000] |
| |
|
| | def ask_question(question: str): |
| | if not question: |
| | return "Please enter a question." |
| | if not DOCS or FAISS_INDEX is None: |
| | return "No indexed data found." |
| | results = search_top_k(question, k=3) |
| | if not results: |
| | return "No relevant documents found." |
| | context = "\n".join(f"Source: {r['source']}\n\n{r['text'][:18000]}\n---\n" for r in results) |
| | user_prompt = f"Use the following context to answer the question.\n\nContext:\n{context}\nQuestion: {question}\nAnswer:" |
| | return openrouter_chat_system_user(user_prompt) |
| |
|
| |
|
| | |
| | with gr.Blocks(title="AI Ally — Crawl4AI + OpenRouter + FAISS") as demo: |
| | gr.Markdown("# 🤖 AI Ally — Document & Website QA\nCrawl4AI for websites, file uploads for docs. FAISS retrieval + sentence-transformers + OpenRouter LLM.") |
| |
|
| | with gr.Tab("Documents"): |
| | file_input = gr.File(label="Upload files", file_count="multiple", |
| | file_types=[".pdf", ".docx", ".txt", ".xlsx", ".pptx", ".csv"]) |
| | upload_btn = gr.Button("Upload & Index") |
| | upload_status = gr.Textbox(label="Status", interactive=False) |
| | preview_box = gr.Textbox(label="Uploads (preview JSON)", interactive=False) |
| | upload_btn.click(upload_and_index, inputs=[file_input], outputs=[upload_status, preview_box]) |
| |
|
| | gr.Markdown("### Ask about your documents") |
| | q = gr.Textbox(label="Question", lines=3) |
| | ask_btn = gr.Button("Ask") |
| | answer_out = gr.Textbox(label="Answer", interactive=False, lines=15) |
| | ask_btn.click(ask_question, inputs=[q], outputs=[answer_out]) |
| |
|
| | with gr.Tab("Website Crawl"): |
| | url = gr.Textbox(label="URL to crawl") |
| | crawl_btn = gr.Button("Crawl & Index") |
| | crawl_status = gr.Textbox(label="Status", interactive=False) |
| | crawl_preview = gr.Textbox(label="Crawl preview", interactive=False) |
| | crawl_btn.click(crawl_and_index, inputs=[url], outputs=[crawl_status, crawl_preview]) |
| |
|
| | q2 = gr.Textbox(label="Question", lines=3) |
| | ask_btn2 = gr.Button("Ask site") |
| | answer_out2 = gr.Textbox(label="Answer", interactive=False, lines=15) |
| | ask_btn2.click(ask_question, inputs=[q2], outputs=[answer_out2]) |
| |
|
| | with gr.Tab("Settings / Info"): |
| | gr.Markdown(f"- Model: `{OPENROUTER_MODEL}`") |
| | gr.Markdown(f"- Embedding model: `{EMBEDDING_MODEL_NAME}`") |
| | gr.Markdown(f"- Cache clears automatically every 5 minutes.") |
| | gr.Markdown(f"- System prompt is fixed internally: `{SYSTEM_PROMPT}`") |
| |
|
| | if __name__ == "__main__": |
| | demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) |
| |
|