Aluode's picture
Update app.py
07fbe8c verified
"""
ConjunctionReservoir Document Chat — HuggingFace Space
=======================================================
Upload any text or PDF document, then ask questions about it.
Retrieval uses sentence-level conjunction scoring (no embeddings needed).
Generation uses HuggingFace Inference API (free, no key required).
"""
import re
import os
import time
import json
import gradio as gr
from pathlib import Path
# ── ConjunctionReservoir ──────────────────────────────────────────────────────
from conjunctionreservoir import ConjunctionReservoir
# ── HuggingFace Inference ─────────────────────────────────────────────────────
from huggingface_hub import InferenceClient
# ── PDF support (optional) ────────────────────────────────────────────────────
try:
import fitz # PyMuPDF
PDF_SUPPORT = True
except ImportError:
try:
import pypdf
PDF_SUPPORT = True
except ImportError:
PDF_SUPPORT = False
# ── Constants ─────────────────────────────────────────────────────────────────
# ── Constants ─────────────────────────────────────────────────────────────────
# Chat model selected by default in the UI, and the model retried when a
# request to the selected model fails.
DEFAULT_MODEL = "Qwen/Qwen2.5-72B-Instruct"
FALLBACK_MODEL = "HuggingFaceH4/zephyr-7b-beta"
MAX_TOKENS = 512  # max_tokens sent with every chat completion request
MAX_HISTORY = 6  # turns to keep in context
# Built-in sample document used by the demo loader. The em dashes below were
# previously stored as mis-decoded UTF-8 ("β€”") and showed up garbled in
# retrieved passages.
DEMO_TEXT = """The ConjunctionReservoir is a document retrieval system that asks not
"do these query terms appear somewhere in this chunk?" but rather
"do these query terms appear in the SAME SENTENCE?"
This is grounded in auditory neuroscience. Norman-Haignere et al. (2025)
showed that auditory cortex integration windows are time-yoked at approximately
80ms — they are fixed clocks, not expanding to cover arbitrary structure.
The sentence is the text analog of this fixed window.
NMDA receptors implement coincidence detection by requiring simultaneous
presynaptic glutamate release and postsynaptic depolarization to open.
This is a hard AND gate, not a weighted average.
The conjunction_threshold parameter mirrors this: below the threshold,
a sentence contributes zero score to the chunk — it is absent, not degraded.
Benchmark results show ConjunctionReservoir achieves 100% Rank-1 Rate on
conjunction-specific queries, compared to 60% for both BM25 and SweepBrain.
It intentionally trades broad-query recall for precision on specific
co-occurrence queries. Use threshold=0.0 to approach standard TF-IDF."""
# ── Text extraction ────────────────────────────────────────────────────────────
def extract_text_from_file(filepath: str) -> str:
    """Extract text from a .txt/.md/.rst/.pdf (or other text-like) file.

    Args:
        filepath: Path of the file to read.

    Returns:
        The extracted text. Failures are not raised: they are reported as a
        string starting with ``"ERROR"`` so callers can show it to the user.
    """
    path = Path(filepath)
    ext = path.suffix.lower()

    if ext == ".pdf":
        if not PDF_SUPPORT:
            return "ERROR: PDF support not available. Please install PyMuPDF or pypdf."
        try:
            import fitz  # PyMuPDF

            doc = fitz.open(filepath)
            try:
                return "\n\n".join(page.get_text() for page in doc)
            finally:
                # Release the file handle even if a page fails to extract.
                doc.close()
        except Exception as fitz_err:
            # PyMuPDF is missing or could not read the file: try pypdf.
            try:
                from pypdf import PdfReader

                reader = PdfReader(filepath)
                return "\n\n".join(p.extract_text() or "" for p in reader.pages)
            except ImportError:
                # pypdf is not installed, so the PyMuPDF failure is the real
                # cause; report that instead of "No module named 'pypdf'".
                return f"ERROR reading PDF: {fitz_err}"
            except Exception as e:
                return f"ERROR reading PDF: {e}"

    # Everything else is read as UTF-8 text; undecodable bytes are replaced
    # rather than rejected. Only the error wording depends on the extension.
    try:
        return path.read_text(encoding="utf-8", errors="replace")
    except (OSError, ValueError) as e:
        if ext in (".txt", ".md", ".rst", ".text"):
            return f"ERROR reading file: {e}"
        return f"ERROR: Unsupported file type {ext}. Try .txt or .pdf"
# ── LLM generation ────────────────────────────────────────────────────────────
def get_client(hf_token: str = "") -> InferenceClient:
    """Create an InferenceClient for the given token.

    A non-blank ``hf_token`` wins; otherwise the ``HF_TOKEN`` environment
    variable is used, and if neither is set the client is built with no token.
    """
    resolved = hf_token.strip()
    if not resolved:
        resolved = os.environ.get("HF_TOKEN", "")
    return InferenceClient(token=resolved or None)
def format_messages(system: str, history: list, user_msg: str) -> list:
    """Assemble the chat-completion message list.

    The result is the system prompt, then the last ``MAX_HISTORY``
    ``(user, assistant)`` pairs from ``history`` as alternating user and
    assistant messages, then ``user_msg`` as the final user message.
    """
    recent_turns = history[-MAX_HISTORY:]
    conversation = [{"role": "system", "content": system}]
    for past_user, past_assistant in recent_turns:
        conversation.extend((
            {"role": "user", "content": past_user},
            {"role": "assistant", "content": past_assistant},
        ))
    conversation.append({"role": "user", "content": user_msg})
    return conversation
def _iter_completion_tokens(client, model, messages):
    """Yield the non-empty content deltas of one streaming chat completion."""
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=0.3,
    )
    for chunk in stream:
        # Some stream chunks carry no choices; skip them instead of indexing.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            yield delta


def stream_response(client, model, messages):
    """Stream tokens from the HF Inference API, with one fallback attempt.

    Args:
        client: Object exposing ``chat.completions.create(...)``.
        model: Model id to try first.
        messages: Chat messages passed through to the API.

    Yields:
        Text fragments as they arrive. If ``model`` fails and is not already
        ``FALLBACK_MODEL``, the request is retried once with
        ``FALLBACK_MODEL``. If no attempt succeeds, a single error message
        built from the first failure is yielded instead of raising.
    """
    try:
        yield from _iter_completion_tokens(client, model, messages)
        return
    except Exception as e:
        primary_error = e

    # NOTE(review): if the first attempt fails after already yielding tokens,
    # the fallback output is appended after that partial text.
    if model != FALLBACK_MODEL:
        try:
            yield from _iter_completion_tokens(client, FALLBACK_MODEL, messages)
            return
        except Exception:
            # Deliberate best effort: the user is told about the failure of
            # the model they picked, not about the fallback's failure.
            pass

    yield f"\n\n⚠️ Generation error: {primary_error}\n\nTip: Add a HuggingFace token in Settings for better rate limits."
# ── Retrieval helpers ─────────────────────────────────────────────────────────
def best_sentence(chunk: str, q_tokens: set) -> tuple:
    """Pick the sentence of ``chunk`` that covers the most query tokens.

    Sentences are split on runs of ``.``, ``!`` and ``?``; fragments of ten
    characters or fewer are ignored. A query token counts as covered when it
    contains, or is contained in, any alphabetic word (3+ letters) of the
    sentence. Returns ``(sentence, coverage)``; when nothing scores above
    zero the first 80 characters of ``chunk`` and ``0.0`` are returned.
    """
    winner, winner_cov = chunk[:80], 0.0
    n_query = len(q_tokens)

    for fragment in re.split(r'[.!?]+', chunk):
        sentence = fragment.strip()
        if len(sentence) <= 10:
            continue

        words = set(re.findall(r'\b[a-zA-Z]{3,}\b', sentence.lower()))
        if n_query:
            covered = 0
            for term in q_tokens:
                for word in words:
                    if term in word or word in term:
                        covered += 1
                        break
            coverage = covered / n_query
        else:
            coverage = 0.0

        # Strictly greater: on ties the earliest sentence is kept.
        if coverage > winner_cov:
            winner, winner_cov = sentence, coverage

    return winner, winner_cov
def do_retrieve(retriever, query: str, threshold: float, n_chunks: int = 3):
    """Retrieve passages for ``query``, loosening the gate if nothing matches.

    Args:
        retriever: Object with a ``conjunction_threshold`` attribute and a
            ``retrieve(query, top_k=..., update_coverage=...)`` method that
            returns ``(chunk, score)`` pairs.
        query: The user's question.
        threshold: Conjunction threshold to apply for this query.
        n_chunks: Maximum number of passages for the strict pass.

    Returns:
        A list of ``(chunk, score)`` pairs with ``score > 0``. If the strict
        pass finds nothing, up to two results from a pass at threshold 0.0
        are returned instead. ``retriever.conjunction_threshold`` is left at
        ``threshold`` afterwards.
    """
    retriever.conjunction_threshold = threshold
    hits = [
        (chunk, score)
        for chunk, score in retriever.retrieve(query, top_k=n_chunks, update_coverage=True)
        if score > 0
    ]
    if hits:
        return hits

    # Loosen and retry. The finally clause guarantees the caller's threshold
    # is restored even when the loosened retrieval raises.
    retriever.conjunction_threshold = 0.0
    try:
        loose_hits = retriever.retrieve(query, top_k=2, update_coverage=False)
    finally:
        retriever.conjunction_threshold = threshold
    return [(chunk, score) for chunk, score in loose_hits if score > 0][:2]
def format_context_for_llm(hits: list) -> str:
    """Render retrieved ``(chunk, score)`` pairs as numbered passages.

    Passages are numbered from 1, stripped of surrounding whitespace and
    separated by a ``---`` rule. A fixed notice is returned for empty input.
    """
    if not hits:
        return "No relevant passages found."

    blocks = []
    for number, (passage, relevance) in enumerate(hits, start=1):
        blocks.append(f"[Passage {number} | relevance {relevance:.3f}]\n{passage.strip()}")
    return "\n\n---\n\n".join(blocks)
def format_retrieval_display(hits: list, q_tokens: set, elapsed_ms: float) -> str:
    """Build the Markdown summary shown under the chat for one retrieval.

    Args:
        hits: ``(chunk, score)`` pairs returned by retrieval.
        q_tokens: Lower-cased query tokens used to pick a preview sentence.
        elapsed_ms: Retrieval time in milliseconds.

    Returns:
        A header line followed by one line per hit with its score and a
        preview (at most 120 characters) of its best-matching sentence, or a
        single warning line when ``hits`` is empty.
    """
    # The bullet, book and arrow symbols below were previously stored as
    # mis-decoded UTF-8 and rendered as garbage in the UI.
    if not hits:
        return f"⚠️ No passages matched (try lowering threshold) • {elapsed_ms:.0f}ms"
    lines = [f"📚 **{len(hits)} passages retrieved** • {elapsed_ms:.0f}ms\n"]
    for i, (chunk, score) in enumerate(hits, 1):
        sent, _coverage = best_sentence(chunk, q_tokens)
        preview = sent[:120] + ("…" if len(sent) > 120 else "")
        lines.append(f"**[{i}]** score={score:.3f} → *\"{preview}\"*")
    return "\n".join(lines)
# ── Main app state ─────────────────────────────────────────────────────────────
class AppState:
    """Mutable holder for the loaded document and the running conversation."""

    def __init__(self):
        # A fresh instance starts with no document and empty histories,
        # which is exactly the state reset_doc() produces.
        self.reset_doc()

    def reset_doc(self):
        """Forget the loaded document and, with it, the conversation."""
        self.retriever = None  # retrieval index for the current document
        self.doc_name = None   # display name of the current document
        self.doc_chars = 0     # length of the document text in characters
        self.reset_chat()

    def reset_chat(self):
        """Drop both conversation histories, keeping the document."""
        self.chat_history = []  # conversation as shown in the chat widget
        self.llm_history = []   # (user, assistant) pairs replayed to the LLM
# ── Build the Gradio UI ────────────────────────────────────────────────────────
def create_app():
    """Build the Gradio Blocks UI and wire up all of its callbacks.

    Returns:
        A ``(demo, css, theme)`` tuple: the Blocks app, plus the CSS string
        and theme object that the caller hands to ``launch()``.
    """
    # NOTE(review): this single AppState is captured by every callback below,
    # so it looks like all visitors share one document and one history —
    # confirm whether per-session state is wanted.
    state = AppState()
    # Carries a value set by the ":threshold" command to the slider update
    # that runs at the end of the turn (see finish_turn).
    pending = {}

    def _load_demo(threshold: float = 0.4):
        """Index the built-in demo text and return the status Markdown."""
        state.reset_doc()
        r = ConjunctionReservoir(conjunction_threshold=float(threshold), coverage_decay=0.04)
        r.build_index(DEMO_TEXT, verbose=False)
        state.retriever = r
        state.doc_name = "ConjunctionReservoir Demo"
        state.doc_chars = len(DEMO_TEXT)
        s = r.summary()
        return (
            f"✅ **{state.doc_name}** loaded \n"
            f"{s['n_chunks']} chunks • {s['n_sentences']} sentences • vocab {s['vocab_size']}"
        )

    # ── Gradio layout ──────────────────────────────────────────────────────────
    css = """
#doc-status { border-left: 4px solid #4CAF50; padding: 8px 12px; background: #f9f9f9; border-radius: 4px; }
#retrieval-info { font-size: 0.85em; color: #555; background: #f5f5f5; padding: 8px; border-radius: 4px; }
.setting-row { display: flex; gap: 12px; align-items: center; }
footer { display: none !important; }
"""
    theme = gr.themes.Soft(primary_hue="blue", neutral_hue="slate")

    # Gradio 6.0 change: removed css and theme from Blocks init.
    with gr.Blocks(
        title="ConjunctionReservoir Document Chat",
    ) as demo:
        # ── Header ─────────────────────────────────────────────────────────────
        gr.Markdown("""
# 🧠 ConjunctionReservoir Document Chat
**Sentence-level conjunction retrieval** — terms must co-appear *in the same sentence* to score.
Grounded in auditory neuroscience (Norman-Haignere 2025, Vollan 2025). Zero embeddings. Millisecond retrieval.
""")
        with gr.Row():
            # ── Left column: document + settings ──────────────────────────────
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("### 📄 Document")
                with gr.Tab("Upload File"):
                    file_input = gr.File(
                        label="Upload .txt or .pdf",
                        file_types=[".txt", ".pdf", ".md"],
                        type="filepath",
                    )
                    upload_btn = gr.Button("📥 Load File", variant="primary")
                with gr.Tab("Paste Text"):
                    text_input = gr.Textbox(
                        label="Paste your text here",
                        lines=8,
                        placeholder="Paste any text...",
                    )
                    paste_name = gr.Textbox(label="Document name", value="pasted_text", max_lines=1)
                    paste_btn = gr.Button("📥 Load Text", variant="primary")
                with gr.Tab("Demo"):
                    gr.Markdown("Load the built-in demo text about ConjunctionReservoir itself.")
                    demo_btn = gr.Button("🧪 Load Demo", variant="secondary")
                doc_status = gr.Markdown("*No document loaded*", elem_id="doc-status")
                gr.Markdown("### ⚙️ Settings")
                threshold_slider = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.4, step=0.05,
                    label="Conjunction threshold",
                    info="Fraction of query terms that must co-appear in a sentence (0=TF-IDF, 1=strict AND)"
                )
                model_dropdown = gr.Dropdown(
                    choices=[
                        "Qwen/Qwen2.5-72B-Instruct",
                        "HuggingFaceH4/zephyr-7b-beta",
                        "microsoft/Phi-3.5-mini-instruct",
                        "mistralai/Mistral-Nemo-Instruct-2407",
                        "meta-llama/Llama-3.2-3B-Instruct",
                    ],
                    value=DEFAULT_MODEL,
                    label="LLM model",
                    info="HuggingFace Inference API (free)"
                )
                hf_token_input = gr.Textbox(
                    label="HuggingFace token (optional)",
                    placeholder="hf_...",
                    type="password",
                    info="Add for higher rate limits. Get one free at huggingface.co/settings/tokens"
                )
                show_retrieval_chk = gr.Checkbox(
                    label="Show retrieved passages",
                    value=True,
                )
                clear_btn = gr.Button("🗑️ Clear conversation", variant="stop", size="sm")
            # ── Right column: chat ─────────────────────────────────────────────
            with gr.Column(scale=2):
                gr.Markdown("### 💬 Chat")
                # Gradio 6.0 change: removed bubble_full_width and render_markdown
                chatbot = gr.Chatbot(
                    label="",
                    height=480,
                    show_label=False,
                )
                retrieval_info = gr.Markdown("", elem_id="retrieval-info")
                with gr.Row():
                    msg_input = gr.Textbox(
                        placeholder="Ask anything about your document…",
                        show_label=False,
                        scale=5,
                        container=False,
                    )
                    send_btn = gr.Button("Send ▶", variant="primary", scale=1)
                gr.Markdown("""
<small>
**Tip:** Try queries that require two concepts together, e.g. *"NMDA coincidence detection"*.
Commands: type `:coverage` to see sweep focus • `:summary` for index stats • `:threshold 0.7` to change on-the-fly
</small>
""")

        # ── Callbacks ──────────────────────────────────────────────────────────
        def _index_text(text, name, threshold):
            """Index ``text`` as the current document.

            Returns ``(status_markdown, chat_history)``: an empty chat on
            success, the previous chat alongside an error status on failure.
            """
            state.reset_doc()
            try:
                r = ConjunctionReservoir(
                    conjunction_threshold=float(threshold),
                    coverage_decay=0.04
                )
                r.build_index(text, verbose=False)
                state.retriever = r
                state.doc_name = name
                state.doc_chars = len(text)
                s = r.summary()
                status = (
                    f"✅ **{name}** loaded \n"
                    f"{s['n_chunks']} chunks • {s['n_sentences']} sentences • "
                    f"vocab {s['vocab_size']} • {s['index_time_ms']:.0f}ms"
                )
                return status, []
            except Exception as e:
                return f"❌ Error indexing: {e}", state.chat_history

        def load_file(filepath, threshold):
            """Load and index an uploaded file."""
            if not filepath:
                return "*No file selected*", state.chat_history
            text = extract_text_from_file(filepath)
            # extract_text_from_file reports failures as "ERROR..." strings.
            if text.startswith("ERROR"):
                return f"❌ {text}", state.chat_history
            return _index_text(text, Path(filepath).name, threshold)

        def load_paste(text, name, threshold):
            """Index text pasted into the textbox."""
            if not text or not text.strip():
                return "*No text provided*", state.chat_history
            return _index_text(text.strip(), name or "pasted_text", threshold)

        def load_demo_cb(threshold):
            """Load the demo document using the slider's current threshold."""
            status = _load_demo(threshold)
            state.chat_history = []
            state.llm_history = []
            return status, []

        def clear_chat():
            """Clear the conversation and the retrieval panel."""
            state.reset_chat()
            return [], ""

        def handle_command(msg: str):
            """Handle special : commands. Returns (response_str, is_command)."""
            cmd = msg.strip().lower()
            if cmd == ":coverage":
                if state.retriever is None:
                    return "No document loaded.", True
                p = state.retriever.coverage_profile()
                lines = [f"**Vollan sweep coverage** (after {p['n_queries']} queries) \n"]
                lines.append(f"Mean coverage: {p['mean_coverage']:.5f} \n")
                if p["most_covered"]:
                    lines.append("**Most visited sentences:**")
                    for sent, cov in p["most_covered"][:5]:
                        lines.append(f"- [{cov:.3f}] {sent[:80]}…")
                return "\n".join(lines), True
            if cmd == ":summary":
                if state.retriever is None:
                    return "No document loaded.", True
                s = state.retriever.summary()
                return (
                    f"**Index summary** \n"
                    + "\n".join(f"- **{k}**: {v}" for k, v in s.items())
                ), True
            # A bare ":threshold" is caught here too, so it gets the usage
            # hint instead of being sent to the LLM as a question.
            if cmd == ":threshold" or cmd.startswith(":threshold "):
                try:
                    val = float(cmd.split()[1])
                except (IndexError, ValueError):
                    return "Usage: `:threshold 0.5`", True
                val = max(0.0, min(1.0, val))
                if state.retriever is not None:
                    state.retriever.conjunction_threshold = val
                # Every query re-applies the slider value, so the slider has
                # to be moved as well or this command would have no effect.
                pending["threshold"] = val
                return f"✅ Threshold set to **{val:.2f}**", True
            if cmd == ":help":
                return (
                    "**Commands:**\n"
                    "- `:coverage` — show Vollan sweep focus\n"
                    "- `:summary` — index statistics\n"
                    "- `:threshold N` — set conjunction gate (0.0–1.0)\n"
                    "- `:help` — this message"
                ), True
            return "", False

        def respond(msg, chat_history, threshold, model, hf_token, show_retrieval):
            """Answer one chat message, streaming ``(chat, retrieval_info)``."""
            # Work on a copy; the widget's value may be None before any turn.
            chat_history = list(chat_history or [])
            if not msg or not msg.strip():
                yield chat_history, ""
                return
            # Commands come first so that they work with no document loaded;
            # the ones that need a document say so themselves.
            cmd_response, is_cmd = handle_command(msg)
            if is_cmd:
                chat_history.append({"role": "user", "content": msg})
                chat_history.append({"role": "assistant", "content": cmd_response})
                state.chat_history = chat_history
                yield chat_history, ""
                return
            if state.retriever is None:
                chat_history.append({"role": "user", "content": msg})
                chat_history.append({"role": "assistant", "content": "⚠️ Please load a document first."})
                state.chat_history = chat_history
                yield chat_history, ""
                return
            # Retrieve
            q_tokens = set(re.findall(r'\b[a-zA-Z]{3,}\b', msg.lower()))
            t0 = time.perf_counter()
            hits = do_retrieve(state.retriever, msg, float(threshold))
            elapsed = (time.perf_counter() - t0) * 1000
            retrieval_display = ""
            if show_retrieval:
                retrieval_display = format_retrieval_display(hits, q_tokens, elapsed)
            # Build LLM prompt
            context_str = format_context_for_llm(hits)
            system = (
                f'You are a document assistant helping the user understand "{state.doc_name}". '
                f'Answer based on the provided passages. Be specific and cite the text when useful. '
                f'If the answer is not in the passages, say so clearly. Keep answers concise.'
            )
            user_with_context = (
                f"Question: {msg}\n\n"
                f"Relevant passages from the document:\n\n{context_str}"
            )
            messages = format_messages(system, state.llm_history[-MAX_HISTORY:], user_with_context)
            # Stream response
            client = get_client(hf_token)
            partial = ""
            # Gradio 6 messages format
            chat_history.append({"role": "user", "content": msg})
            chat_history.append({"role": "assistant", "content": ""})
            for token in stream_response(client, model, messages):
                partial += token
                chat_history[-1]["content"] = partial
                yield chat_history, retrieval_display
            # Save to history. Only the bare question is kept for later
            # turns, not the retrieved passages that were sent with it.
            state.llm_history.append((f"Question: {msg}", partial))
            state.chat_history = chat_history

        def finish_turn(threshold):
            """Clear the input box and apply any ``:threshold`` override.

            Returns ``(new_input_text, slider_value)``; the slider keeps its
            current value unless a ``:threshold`` command ran this turn.
            """
            override = pending.pop("threshold", None)
            return "", (threshold if override is None else override)

        # ── Wire events ────────────────────────────────────────────────────────
        upload_btn.click(
            load_file,
            inputs=[file_input, threshold_slider],
            outputs=[doc_status, chatbot],
        )
        paste_btn.click(
            load_paste,
            inputs=[text_input, paste_name, threshold_slider],
            outputs=[doc_status, chatbot],
        )
        demo_btn.click(
            load_demo_cb,
            inputs=[threshold_slider],
            outputs=[doc_status, chatbot],
        )
        clear_btn.click(clear_chat, outputs=[chatbot, retrieval_info])
        # The Send button and pressing Enter trigger the same pipeline.
        respond_inputs = [msg_input, chatbot, threshold_slider, model_dropdown,
                          hf_token_input, show_retrieval_chk]
        for trigger in (send_btn.click, msg_input.submit):
            trigger(
                respond,
                inputs=respond_inputs,
                outputs=[chatbot, retrieval_info],
            ).then(
                finish_turn,
                inputs=[threshold_slider],
                outputs=[msg_input, threshold_slider],
            )
        # Load demo on startup
        demo.load(_load_demo, outputs=[doc_status])
    return demo, css, theme
if __name__ == "__main__":
    # Gradio 6.0 change: css and theme are supplied to launch(), so
    # create_app() hands them back alongside the Blocks object.
    demo_app, stylesheet, ui_theme = create_app()
    demo_app.launch(css=stylesheet, theme=ui_theme, share=False)