Spaces:
Sleeping
Sleeping
| """ | |
ConjunctionReservoir Document Chat — HuggingFace Space
======================================================
Upload a text (.txt, .md) or PDF document, then ask questions about it.
Retrieval uses sentence-level conjunction scoring (no embeddings needed).
Generation uses the HuggingFace Inference API (free; an optional token gives
higher rate limits).
| """ | |
| import re | |
| import os | |
| import time | |
| import json | |
| import gradio as gr | |
| from pathlib import Path | |
# ── ConjunctionReservoir ──────────────────────────────────────────────────────
| from conjunctionreservoir import ConjunctionReservoir | |
# ── HuggingFace Inference ─────────────────────────────────────────────────────
| from huggingface_hub import InferenceClient | |
# ── PDF support (optional): PyMuPDF preferred, pypdf as the fallback ─────────
PDF_SUPPORT = True
try:
    import fitz  # PyMuPDF
except ImportError:
    try:
        import pypdf
    except ImportError:
        # Neither backend is installed; extract_text_from_file reports this
        # to the user when a .pdf is uploaded.
        PDF_SUPPORT = False
# ── Constants ─────────────────────────────────────────────────────────────────
DEFAULT_MODEL = "Qwen/Qwen2.5-72B-Instruct"
FALLBACK_MODEL = "HuggingFaceH4/zephyr-7b-beta"  # retried if the chosen model fails
MAX_TOKENS = 512  # max_tokens cap for each generated answer
MAX_HISTORY = 6  # (question, answer) turns replayed to the LLM as context
# Built-in sample document loaded on startup and by the "Load Demo" button.
DEMO_TEXT = """The ConjunctionReservoir is a document retrieval system that asks not
"do these query terms appear somewhere in this chunk?" but rather
"do these query terms appear in the SAME SENTENCE?"
This is grounded in auditory neuroscience. Norman-Haignere et al. (2025)
showed that auditory cortex integration windows are time-yoked at approximately
80ms — they are fixed clocks, not expanding to cover arbitrary structure.
The sentence is the text analog of this fixed window.
NMDA receptors implement coincidence detection by requiring simultaneous
presynaptic glutamate release and postsynaptic depolarization to open.
This is a hard AND gate, not a weighted average.
The conjunction_threshold parameter mirrors this: below the threshold,
a sentence contributes zero score to the chunk — it is absent, not degraded.
Benchmark results show ConjunctionReservoir achieves 100% Rank-1 Rate on
conjunction-specific queries, compared to 60% for both BM25 and SweepBrain.
It intentionally trades broad-query recall for precision on specific
co-occurrence queries. Use threshold=0.0 to approach standard TF-IDF."""
# ── Text extraction ───────────────────────────────────────────────────────────
def _read_pdf_text(filepath: str) -> str:
    """Return the text of every page of a PDF, pages joined by blank lines.

    Tries PyMuPDF (``fitz``) first and falls back to ``pypdf``. If PyMuPDF fails
    and pypdf is not installed, the PyMuPDF error is re-raised so the caller
    reports the real problem rather than a missing-module message.
    """
    try:
        import fitz  # PyMuPDF

        # The context manager closes the document handle once the text is read.
        with fitz.open(filepath) as doc:
            return "\n\n".join(page.get_text() for page in doc)
    except Exception as fitz_error:
        try:
            from pypdf import PdfReader
        except ImportError:
            raise fitz_error from None
        reader = PdfReader(filepath)
        # extract_text() may yield None/empty for pages without a text layer.
        return "\n\n".join(page.extract_text() or "" for page in reader.pages)


def extract_text_from_file(filepath: str) -> str:
    """Extract plain text from an uploaded document.

    ``.pdf`` files go through PyMuPDF or pypdf; every other extension is read
    as UTF-8 text with undecodable bytes replaced by U+FFFD.

    Args:
        filepath: Path of the file on disk.

    Returns:
        The extracted text. Failures are not raised: a human-readable message
        starting with ``"ERROR"`` is returned instead, and callers test for
        that prefix. Files with no extractable text (empty, or image-only PDFs)
        also produce an ``"ERROR"`` message rather than an empty document.
    """
    path = Path(filepath)
    ext = path.suffix.lower()

    if ext == ".pdf":
        if not PDF_SUPPORT:
            return "ERROR: PDF support not available. Please install PyMuPDF or pypdf."
        try:
            text = _read_pdf_text(filepath)
        except Exception as e:
            return f"ERROR reading PDF: {e}"
    else:
        # Unknown extensions are still attempted as text; only a read failure
        # is reported as "unsupported".
        try:
            text = path.read_text(encoding="utf-8", errors="replace")
        except OSError as e:
            if ext in (".txt", ".md", ".rst", ".text"):
                return f"ERROR reading file: {e}"
            return f"ERROR: Unsupported file type {ext}. Try .txt or .pdf"

    if not text.strip():
        return "ERROR: No extractable text found in the file (scanned PDFs need OCR first)."
    return text
# ── LLM generation ────────────────────────────────────────────────────────────
def get_client(hf_token: str = "") -> InferenceClient:
    """Create a HuggingFace ``InferenceClient``.

    The token typed in the UI takes precedence, then the ``HF_TOKEN``
    environment variable; with neither, the client is created without a token.

    Args:
        hf_token: Token from the settings textbox. ``None`` is treated like
            ``""`` (NOTE(review): Gradio may send None for an untouched
            textbox — unverified, handled defensively).
    """
    token = (hf_token or "").strip() or os.environ.get("HF_TOKEN", "").strip()
    return InferenceClient(token=token or None)
def format_messages(system: str, history: list, user_msg: str) -> list:
    """Assemble a chat-completion message list.

    Layout: the system prompt, then the newest ``MAX_HISTORY`` entries of
    *history* (each a ``(user, assistant)`` pair) as alternating user /
    assistant messages, and finally *user_msg* as the last user message.
    """
    replayed = []
    for question, answer in history[-MAX_HISTORY:]:
        replayed += [
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
    return [
        {"role": "system", "content": system},
        *replayed,
        {"role": "user", "content": user_msg},
    ]
def _stream_model(client, model, messages):
    """Yield the non-empty text deltas of one streaming chat completion."""
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=0.3,
    )
    for chunk in stream:
        # Streamed chunks may arrive with an empty ``choices`` list; skip them.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta


def stream_response(client, model, messages):
    """Stream answer text from the HF Inference API, with one fallback model.

    *model* is tried first. If that request fails before producing any text,
    ``FALLBACK_MODEL`` is tried once. If a model fails after text was already
    yielded, no fallback is attempted: a second model's answer would be
    spliced onto the first, so an "interrupted" note is yielded instead.
    Errors never propagate; they are reported as a final warning fragment so
    the chat UI always receives a reply.

    Yields:
        Text fragments to append to the assistant message.
    """
    candidates = [model] if model == FALLBACK_MODEL else [model, FALLBACK_MODEL]
    errors = []
    for candidate in candidates:
        produced = False
        try:
            for delta in _stream_model(client, candidate, messages):
                produced = True
                yield delta
            return
        except Exception as exc:
            if produced:
                yield f"\n\n⚠️ Generation interrupted: {exc}"
                return
            errors.append(f"{candidate}: {exc}")
    yield (
        f"\n\n⚠️ Generation error: {'; '.join(errors)}\n\n"
        "Tip: Add a HuggingFace token in Settings for better rate limits."
    )
# ── Retrieval helpers ─────────────────────────────────────────────────────────
def best_sentence(chunk: str, q_tokens: set) -> tuple:
    """Pick the sentence of *chunk* that covers the most query tokens.

    Sentences are the pieces between runs of ``.``, ``!`` or ``?``; pieces of
    10 characters or fewer (after stripping) are skipped. A query token counts
    as covered when it is a substring of some 3+-letter word of the
    lower-cased sentence, or that word is a substring of the token.

    Returns:
        ``(sentence, coverage)``, where coverage is the covered fraction of
        *q_tokens*. The first sentence reaching the highest coverage wins; if
        nothing covers any token, the first 80 characters of *chunk* and 0.0.
    """
    top_sentence = chunk[:80]
    top_coverage = 0.0
    n_query = len(q_tokens)
    for raw in re.split(r'[.!?]+', chunk):
        sentence = raw.strip()
        if len(sentence) <= 10:
            continue
        words = set(re.findall(r'\b[a-zA-Z]{3,}\b', sentence.lower()))
        covered = 0
        for token in q_tokens:
            if any(token in word or word in token for word in words):
                covered += 1
        coverage = covered / n_query if n_query else 0.0
        if coverage > top_coverage:
            top_sentence, top_coverage = sentence, coverage
    return top_sentence, top_coverage
def do_retrieve(retriever, query: str, threshold: float, n_chunks: int = 3):
    """Retrieve the best passages for *query*, loosening the gate if none pass.

    Sets ``retriever.conjunction_threshold`` to *threshold* (it keeps that value
    afterwards) and asks for *n_chunks* hits with coverage updates enabled. If
    no hit scores above zero, retries once with the threshold at 0.0 for at
    most two hits, without updating coverage, then restores the threshold —
    also when that retry raises.

    Returns:
        List of ``(chunk_text, score)`` pairs with ``score > 0``; may be empty.
    """
    retriever.conjunction_threshold = threshold
    hits = retriever.retrieve(query, top_k=n_chunks, update_coverage=True)
    hits = [(c, s) for c, s in hits if s > 0]
    if hits:
        return hits

    # Nothing cleared the conjunction gate: retry with the gate fully open.
    previous = retriever.conjunction_threshold
    retriever.conjunction_threshold = 0.0
    try:
        loose = retriever.retrieve(query, top_k=2, update_coverage=False)
    finally:
        # Never leave the retriever stuck at 0.0 if retrieve() raises.
        retriever.conjunction_threshold = previous
    return [(c, s) for c, s in loose if s > 0][:2]
def format_context_for_llm(hits: list) -> str:
    """Render retrieved ``(chunk, score)`` pairs as numbered prompt passages.

    Passages are numbered from 1, labelled with the score to three decimals,
    stripped of surrounding whitespace and separated by a ``---`` rule. An
    empty *hits* list yields a fixed placeholder so the prompt is never blank.
    """
    if not hits:
        return "No relevant passages found."
    separator = "\n\n---\n\n"
    rendered = []
    for number, (passage, relevance) in enumerate(hits, start=1):
        header = f"[Passage {number} | relevance {relevance:.3f}]"
        rendered.append(f"{header}\n{passage.strip()}")
    return separator.join(rendered)
def format_retrieval_display(hits: list, q_tokens: set, elapsed_ms: float) -> str:
    """Build the Markdown retrieval summary shown under the chat.

    Each hit is previewed by its best-matching sentence (``best_sentence``),
    cut to 120 characters with an ellipsis when longer.

    Args:
        hits: ``(chunk_text, score)`` pairs, best first.
        q_tokens: Lower-cased query words used to choose the preview sentence.
        elapsed_ms: Retrieval time in milliseconds, shown rounded.
    """
    if not hits:
        return f"⚠️ No passages matched (try lowering threshold) • {elapsed_ms:.0f}ms"
    noun = "passage" if len(hits) == 1 else "passages"
    lines = [f"🔍 **{len(hits)} {noun} retrieved** • {elapsed_ms:.0f}ms\n"]
    for i, (chunk, score) in enumerate(hits, 1):
        sent, _coverage = best_sentence(chunk, q_tokens)
        preview = sent if len(sent) <= 120 else sent[:120] + "…"
        lines.append(f"**[{i}]** score={score:.3f} — *\"{preview}\"*")
    return "\n".join(lines)
# ── Main app state ────────────────────────────────────────────────────────────
class AppState:
    """Mutable holder for the loaded document and its conversation.

    NOTE(review): create_app() builds one instance that every callback
    captures, so all browser sessions share it.
    """

    def __init__(self):
        self.retriever = None    # ConjunctionReservoir index of the current document, or None
        self.doc_name = None     # display name of the current document
        self.doc_chars = 0       # length of the indexed text, in characters
        self.chat_history = []   # Chatbot messages: {"role": ..., "content": ...} dicts
        self.llm_history = []    # ("Question: ...", answer) tuples replayed to the LLM (no passages)

    def reset_doc(self):
        """Forget the current document and its conversation."""
        self.retriever = None
        self.doc_name = None
        self.doc_chars = 0
        self.reset_chat()

    def reset_chat(self):
        """Clear the displayed chat and the LLM history; keep the document."""
        self.chat_history = []
        self.llm_history = []
# ── Build the Gradio UI ───────────────────────────────────────────────────────
def create_app():
    """Build the ConjunctionReservoir document-chat Gradio app.

    Returns:
        ``(demo, css, theme)``: the ``gr.Blocks`` app, the custom CSS string and
        the theme. CSS and theme are returned separately because under Gradio 6
        they are passed to ``launch()`` rather than to ``gr.Blocks()``.

    NOTE(review): ``state`` and ``threshold_override`` are created once per
    ``create_app()`` call and captured by every callback, so all browser
    sessions share one document, one conversation and one threshold override.
    ``demo.load`` also re-indexes the demo on every page load, replacing any
    loaded document. Isolating users would need per-session ``gr.State``.
    """
    state = AppState()
    # Threshold chosen with the ":threshold N" chat command. While not None it
    # takes precedence over the slider; moving the slider clears it again.
    threshold_override = None

    def _load_demo(threshold=0.4):
        """Index DEMO_TEXT as the current document and return a Markdown status."""
        state.reset_doc()  # also clears the conversation
        r = ConjunctionReservoir(conjunction_threshold=float(threshold), coverage_decay=0.04)
        r.build_index(DEMO_TEXT, verbose=False)
        state.retriever = r
        state.doc_name = "ConjunctionReservoir Demo"
        state.doc_chars = len(DEMO_TEXT)
        s = r.summary()
        return (
            f"✅ **{state.doc_name}** loaded  \n"
            f"{s['n_chunks']} chunks • {s['n_sentences']} sentences • vocab {s['vocab_size']}"
        )

    # ── Gradio layout ─────────────────────────────────────────────────────────
    css = """
    #doc-status { border-left: 4px solid #4CAF50; padding: 8px 12px; background: #f9f9f9; border-radius: 4px; }
    #retrieval-info { font-size: 0.85em; color: #555; background: #f5f5f5; padding: 8px; border-radius: 4px; }
    .setting-row { display: flex; gap: 12px; align-items: center; }
    footer { display: none !important; }
    """
    theme = gr.themes.Soft(primary_hue="blue", neutral_hue="slate")

    # Gradio 6.0 change: css and theme are no longer passed to gr.Blocks();
    # they are returned to the caller for launch().
    with gr.Blocks(
        title="ConjunctionReservoir Document Chat",
    ) as demo:
        # ── Header ────────────────────────────────────────────────────────────
        gr.Markdown("""
# 🧠 ConjunctionReservoir Document Chat
**Sentence-level conjunction retrieval** — terms must co-appear *in the same sentence* to score.
Grounded in auditory neuroscience (Norman-Haignere 2025, Vollan 2025). Zero embeddings. Millisecond retrieval.
""")
        with gr.Row():
            # ── Left column: document + settings ──────────────────────────────
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### 📄 Document")
                with gr.Tab("Upload File"):
                    file_input = gr.File(
                        label="Upload .txt or .pdf",
                        file_types=[".txt", ".pdf", ".md"],
                        type="filepath",
                    )
                    upload_btn = gr.Button("📥 Load File", variant="primary")
                with gr.Tab("Paste Text"):
                    text_input = gr.Textbox(
                        label="Paste your text here",
                        lines=8,
                        placeholder="Paste any text...",
                    )
                    paste_name = gr.Textbox(label="Document name", value="pasted_text", max_lines=1)
                    paste_btn = gr.Button("📥 Load Text", variant="primary")
                with gr.Tab("Demo"):
                    gr.Markdown("Load the built-in demo text about ConjunctionReservoir itself.")
                    demo_btn = gr.Button("🧪 Load Demo", variant="secondary")
                doc_status = gr.Markdown("*No document loaded*", elem_id="doc-status")

                gr.Markdown("### ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.4, step=0.05,
                    label="Conjunction threshold",
                    info="Fraction of query terms that must co-appear in a sentence (0=TF-IDF, 1=strict AND)",
                )
                model_dropdown = gr.Dropdown(
                    choices=[
                        "Qwen/Qwen2.5-72B-Instruct",
                        "HuggingFaceH4/zephyr-7b-beta",
                        "microsoft/Phi-3.5-mini-instruct",
                        "mistralai/Mistral-Nemo-Instruct-2407",
                        "meta-llama/Llama-3.2-3B-Instruct",
                    ],
                    value=DEFAULT_MODEL,
                    label="LLM model",
                    info="HuggingFace Inference API (free)",
                )
                hf_token_input = gr.Textbox(
                    label="HuggingFace token (optional)",
                    placeholder="hf_...",
                    type="password",
                    info="Add for higher rate limits. Get one free at huggingface.co/settings/tokens",
                )
                show_retrieval_chk = gr.Checkbox(
                    label="Show retrieved passages",
                    value=True,
                )
                clear_btn = gr.Button("🗑️ Clear conversation", variant="stop", size="sm")

            # ── Right column: chat ────────────────────────────────────────────
            with gr.Column(scale=2):
                gr.Markdown("### 💬 Chat")
                # Gradio 6.0 change: removed bubble_full_width and render_markdown
                chatbot = gr.Chatbot(
                    label="",
                    height=480,
                    show_label=False,
                )
                retrieval_info = gr.Markdown("", elem_id="retrieval-info")
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Ask anything about your document…",
                        show_label=False,
                        scale=5,
                        container=False,
                    )
                    send_btn = gr.Button("Send ▶", variant="primary", scale=1)
                gr.Markdown("""
<small>
**Tip:** Try queries that require two concepts together, e.g. *"NMDA coincidence detection"*.
Commands: type `:coverage` to see sweep focus • `:summary` for index stats • `:threshold 0.7` to change on-the-fly • `:help` for all commands
</small>
""")

        # ── Callbacks ─────────────────────────────────────────────────────────
        def load_file(filepath, threshold):
            """Extract and index an uploaded file; returns (status, chatbot value)."""
            if not filepath:
                return "*No file selected*", state.chat_history
            text = extract_text_from_file(filepath)
            if text.startswith("ERROR"):
                return f"❌ {text}", state.chat_history
            return _index_text(text, Path(filepath).name, threshold)

        def load_paste(text, name, threshold):
            """Index pasted text; returns (status, chatbot value)."""
            if not text or not text.strip():
                return "*No text provided*", state.chat_history
            return _index_text(text.strip(), (name or "").strip() or "pasted_text", threshold)

        def load_demo_cb(threshold):
            """Load the built-in demo; _load_demo() already clears the conversation."""
            return _load_demo(threshold), []

        def _index_text(text, name, threshold):
            """Index *text* as the current document; returns (status, chatbot value)."""
            state.reset_doc()
            try:
                r = ConjunctionReservoir(
                    conjunction_threshold=float(threshold),
                    coverage_decay=0.04,
                )
                r.build_index(text, verbose=False)
                s = r.summary()
            except Exception as e:
                return f"❌ Error indexing: {e}", state.chat_history
            # Only publish the new index once indexing and summary both succeeded.
            state.retriever = r
            state.doc_name = name
            state.doc_chars = len(text)
            status = (
                f"✅ **{name}** loaded  \n"
                f"{s['n_chunks']} chunks • {s['n_sentences']} sentences • "
                f"vocab {s['vocab_size']} • {s['index_time_ms']:.0f}ms"
            )
            return status, []

        def clear_chat():
            """Clear the conversation and the retrieval summary."""
            state.reset_chat()
            return [], ""

        def _clear_threshold_override():
            """Let the slider govern retrieval again once the user moves it."""
            nonlocal threshold_override
            threshold_override = None

        def handle_command(msg: str):
            """Handle special ``:`` commands. Returns (response_markdown, is_command)."""
            nonlocal threshold_override
            cmd = msg.strip().lower()
            parts = cmd.split()
            if cmd == ":coverage":
                if state.retriever is None:
                    return "No document loaded.", True
                p = state.retriever.coverage_profile()
                lines = [f"**Vollan sweep coverage** (after {p['n_queries']} queries)  \n"]
                lines.append(f"Mean coverage: {p['mean_coverage']:.5f}  \n")
                if p["most_covered"]:
                    lines.append("**Most visited sentences:**")
                    for sent, cov in p["most_covered"][:5]:
                        snippet = sent if len(sent) <= 80 else sent[:80] + "…"
                        lines.append(f"- [{cov:.3f}] {snippet}")
                return "\n".join(lines), True
            if cmd == ":summary":
                if state.retriever is None:
                    return "No document loaded.", True
                s = state.retriever.summary()
                return (
                    "**Index summary**  \n"
                    + "\n".join(f"- **{k}**: {v}" for k, v in s.items())
                ), True
            if parts and parts[0] == ":threshold":
                try:
                    val = float(parts[1])
                except (IndexError, ValueError):
                    return "Usage: `:threshold 0.5`", True
                val = max(0.0, min(1.0, val))
                # Remembered so respond() uses it instead of the slider value;
                # otherwise do_retrieve() would overwrite it on the next query.
                threshold_override = val
                if state.retriever is not None:
                    state.retriever.conjunction_threshold = val
                return f"✅ Threshold set to **{val:.2f}** (until the slider is moved)", True
            if cmd == ":help":
                return (
                    "**Commands:**\n"
                    "- `:coverage` — show Vollan sweep focus\n"
                    "- `:summary` — index statistics\n"
                    "- `:threshold N` — set conjunction gate (0.0–1.0)\n"
                    "- `:help` — this message"
                ), True
            return "", False

        def respond(msg, chat_history, threshold, model, hf_token, show_retrieval):
            """Answer one chat message, yielding (chatbot messages, retrieval info) updates."""
            chat_history = list(chat_history or [])
            if not msg or not msg.strip():
                yield chat_history, ""
                return

            # Commands run before the document check so ":help" and
            # ":threshold" work even when nothing is loaded.
            cmd_response, is_cmd = handle_command(msg)
            if is_cmd:
                chat_history.append({"role": "user", "content": msg})
                chat_history.append({"role": "assistant", "content": cmd_response})
                yield chat_history, ""
                return

            if state.retriever is None:
                chat_history.append({"role": "user", "content": msg})
                chat_history.append({"role": "assistant", "content": "⚠️ Please load a document first."})
                yield chat_history, ""
                return

            # Retrieve (a ":threshold" command wins over the slider until it moves).
            effective_threshold = (
                threshold_override if threshold_override is not None else float(threshold)
            )
            q_tokens = set(re.findall(r'\b[a-zA-Z]{3,}\b', msg.lower()))
            t0 = time.perf_counter()
            hits = do_retrieve(state.retriever, msg, effective_threshold)
            elapsed = (time.perf_counter() - t0) * 1000
            retrieval_display = (
                format_retrieval_display(hits, q_tokens, elapsed) if show_retrieval else ""
            )

            # Build LLM prompt
            context_str = format_context_for_llm(hits)
            system = (
                f'You are a document assistant helping the user understand "{state.doc_name}". '
                f'Answer based on the provided passages. Be specific and cite the text when useful. '
                f'If the answer is not in the passages, say so clearly. Keep answers concise.'
            )
            user_with_context = (
                f"Question: {msg}\n\n"
                f"Relevant passages from the document:\n\n{context_str}"
            )
            messages = format_messages(system, state.llm_history[-MAX_HISTORY:], user_with_context)

            # Stream response (Gradio 6 messages format)
            client = get_client(hf_token)
            chat_history.append({"role": "user", "content": msg})
            chat_history.append({"role": "assistant", "content": ""})
            # Show the question and retrieval summary right away; otherwise nothing
            # appears until the first token, or at all if the model yields no text.
            yield chat_history, retrieval_display
            partial = ""
            for token in stream_response(client, model, messages):
                partial += token
                chat_history[-1]["content"] = partial
                yield chat_history, retrieval_display

            # Only the bare question is kept for the LLM; passages are
            # re-retrieved on every turn rather than replayed.
            state.llm_history.append((f"Question: {msg}", partial))
            state.chat_history = chat_history

        # ── Wire events ───────────────────────────────────────────────────────
        upload_btn.click(
            load_file,
            inputs=[file_input, threshold_slider],
            outputs=[doc_status, chatbot],
        )
        paste_btn.click(
            load_paste,
            inputs=[text_input, paste_name, threshold_slider],
            outputs=[doc_status, chatbot],
        )
        demo_btn.click(
            load_demo_cb,
            inputs=[threshold_slider],
            outputs=[doc_status, chatbot],
        )
        clear_btn.click(clear_chat, outputs=[chatbot, retrieval_info])
        threshold_slider.change(_clear_threshold_override, inputs=None, outputs=None)
        chat_inputs = [msg_input, chatbot, threshold_slider, model_dropdown,
                       hf_token_input, show_retrieval_chk]
        send_btn.click(
            respond,
            inputs=chat_inputs,
            outputs=[chatbot, retrieval_info],
        ).then(lambda: "", outputs=[msg_input])
        msg_input.submit(
            respond,
            inputs=chat_inputs,
            outputs=[chatbot, retrieval_info],
        ).then(lambda: "", outputs=[msg_input])
        # Load demo on startup
        demo.load(_load_demo, outputs=[doc_status])

    return demo, css, theme
def _main() -> None:
    """Build the app and serve it locally.

    Gradio 6 takes ``css`` and ``theme`` in ``launch()`` rather than in
    ``gr.Blocks()``, which is why ``create_app`` hands them back separately.
    """
    blocks_app, custom_css, ui_theme = create_app()
    blocks_app.launch(share=False, css=custom_css, theme=ui_theme)


if __name__ == "__main__":
    _main()