DECADE / main.py
anonymous-penguin's picture
Initial code release
9c60174 verified
import argparse
import os
import re
import json
import time
from json import JSONDecodeError
from datetime import datetime, timedelta
from typing import List, Dict, Any
from openai import OpenAI
try:
    from openai import AzureOpenAI
    from azure.identity import (
        AzureCliCredential,
        ChainedTokenCredential,
        ManagedIdentityCredential,
        get_bearer_token_provider,
    )
    # OAuth scope for the Azure bearer-token provider.  Empty by default; set
    # AZURE_OAUTH_SCOPE per deployment to enable Azure AD authentication.
    AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "")
    credential = (
        get_bearer_token_provider(
            ChainedTokenCredential(AzureCliCredential(), ManagedIdentityCredential()),
            AZURE_OAUTH_SCOPE,
        )
        if AZURE_OAUTH_SCOPE
        else None
    )
except ImportError:
    # The Azure SDKs are optional; without them the Azure branch of llm_call()
    # cannot be used, but the OpenAI-compatible backends still can.
    AzureOpenAI = None
    credential = None
from model_zoo import model_zoo
from memory import EpisodicMemoryStore, SemanticMemoryStore
try:
import tiktoken
except ImportError:
tiktoken = None
try:
from transformers import AutoTokenizer, PreTrainedTokenizerBase
except ImportError:
AutoTokenizer = None
PreTrainedTokenizerBase = ()
from collections import defaultdict
def get_hf_tokenizer_for_vllm(model_name: str):
    """Load the fast Hugging Face tokenizer for *model_name*.

    Raises:
        ImportError: if ``transformers`` is not installed (``AutoTokenizer`` is None),
            instead of the opaque AttributeError the bare call would produce.
    """
    if AutoTokenizer is None:
        raise ImportError(
            "transformers is required for get_hf_tokenizer_for_vllm(); "
            "install it with `pip install transformers`."
        )
    # NOTE: trust_remote_code=True executes tokenizer code shipped with the model
    # repository; only use it with model sources you trust.
    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
# Azure OpenAI endpoint used by the default (Azure) branch of llm_call();
# set AZURE_OPENAI_ENDPOINT to your deployment URL.
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "")
# Base URL of the OpenAI-compatible LiteLLM proxy used when tritonai=True.
# The constant keeps its historical name but is read from LITELLM_BASE_URL.
TRITONAI_BASE_URL = os.getenv("LITELLM_BASE_URL", "")
# ---------------------------------------------------------------------------
# Cached intermediate results
# ---------------------------------------------------------------------------
# qid -> cached plan.  Loaded from $plan_cache when that file exists; when the
# variable is unset, plan_cache_file falls back to a default path and the cache
# starts empty.
qid2plan = {}
if 'plan_cache' in os.environ:
    plan_cache_file = os.environ['plan_cache']
    if os.path.exists(plan_cache_file):
        with open(plan_cache_file, encoding="utf-8") as _f:
            qid2plan = json.load(_f)
else:
    plan_cache_file = 'response_cache/qa/evolv_mem_v3_plan_cache_gpt5-1'

# qid -> session ids kept by an earlier reading/verification pass, replayed by
# RetrievalAgent._read_and_verify_with_cache().  $reading_cache is required: a
# missing variable raises KeyError at import time.
veri_reading_log_file = os.environ['reading_cache']
qid2rel_sess_ids = {}
if os.path.exists(veri_reading_log_file):
    with open(veri_reading_log_file, encoding="utf-8") as _f:
        qid2rel_sess_ids = json.load(_f)
# Cache of pre-computed retrieval results, so expensive retrieval need not be
# re-run.  JSONL: one complete retrieval record per line, keyed by question_id,
# holding question metadata (id, type, text, answer, dates), haystack info
# (session dates, content, ids) and the retrieval results (query, ranked items,
# evaluation metrics).  Enabled by setting $ret_cache to the file path; when it
# is unset, the lookup tables below stay empty instead of being undefined.
retrieved_log_file = None
retrieved_data = []
retrieved_data_dict = {}
if 'ret_cache' in os.environ:
    retrieved_log_file = os.environ['ret_cache']
    print("loading existing retrieved results ...")
    with open(retrieved_log_file, encoding="utf-8") as _f:
        # Blank lines are skipped rather than failing json.loads.
        retrieved_data = [json.loads(line) for line in _f if line.strip()]
    retrieved_data_dict = {x['question_id']: x for x in retrieved_data}

# Every known session id; used to map retrieval corpus ids back to sessions.
with open("dataset/all_sessions.json", encoding="utf-8") as _f:
    valid_sess_set = set(json.load(_f).keys())
def parse_json(response_content):
    """Best-effort extraction of a JSON value from an LLM response string.

    Candidate substrings are tried in this order:
      1. the text between an opening ```json fence and the last ``` in the string;
      2. the span from the first '{' to the last '}'.
    The first candidate that ``json.loads`` accepts is returned.  When none parse,
    a warning and the first 500 characters of the response are printed and an
    empty dict is returned.
    """
    fence = '```json'
    candidates = []

    fence_pos = response_content.find(fence)
    if fence_pos != -1:
        body_start = fence_pos + len(fence)
        body_end = response_content.rfind('```')
        if body_end > body_start:
            candidates.append(response_content[body_start:body_end].strip())

    # Fallback: outermost brace span.
    open_brace = response_content.find('{')
    close_brace = response_content.rfind('}') + 1
    if open_brace >= 0 and close_brace > open_brace:
        candidates.append(response_content[open_brace:close_brace].strip())

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except (JSONDecodeError, ValueError):
            pass

    print("[Warning] Failed to decode JSON from response (all strategies failed)")
    print(f"[Debug] Raw response content (truncated): {response_content[:500]}")
    return {}
def _retryable_status(e):
# Try common attributes first — return any extractable HTTP status code
status = getattr(e, "status_code", None) or getattr(e, "http_status", None)
if status is not None:
return int(status)
resp = getattr(e, "response", None)
if resp is not None and getattr(resp, "status_code", None) is not None:
return int(resp.status_code)
# Fallback: infer from message text
msg = str(e).lower()
if "429" in msg or "rate limit" in msg:
return 429
if "500" in msg or "internal server error" in msg:
return 500
if "503" in msg or "API Configuration unavailable" in msg:
return 503
return None
# Default context window (tokens) assumed by llm_call() and
# RetrievalAgent.is_answerable() for the Azure / OpenAI / NVIDIA backends;
# the vLLM and LiteLLM-proxy paths use 131,072 instead.
MAX_CONTEXT_TOKENS = 272 * 1000
class CharacterEncoder:
    """Last-resort "tokenizer" used when neither tiktoken nor transformers is available.

    Every character counts as one token.  For typical text this over-counts
    relative to subword tokenizers, so budgets computed with it err on the safe side.
    """

    def encode(self, text, **kwargs):
        # Keyword arguments such as ``disallowed_special`` are accepted for
        # call-compatibility with tiktoken and ignored.
        return [ch for ch in text]

    def decode(self, toks):
        return "".join(toks)
# model name -> encoder.  _get_encoder() runs on every llm_call() and
# is_answerable() call, and loading an HF tokenizer is expensive, so results are
# memoized for the lifetime of the process.
_ENCODER_CACHE = {}


def _get_encoder(model_name: str):
    """Return a token encoder for *model_name*, memoized per model name.

    Resolution order:
      1. Names containing "Qwen" -> HF tokenizer of Qwen/Qwen3-30B-A3B-Instruct-2507
         (when transformers is installed and the load succeeds).
      2. ``tiktoken.encoding_for_model(model_name)``, falling back to
         ``cl100k_base`` for names tiktoken does not know.
      3. :class:`CharacterEncoder` when tiktoken is unavailable.

    Whichever encoder is chosen is cached, including a fallback chosen after a
    failed Qwen load, so that failure is reported once per model name.
    """
    if model_name in _ENCODER_CACHE:
        return _ENCODER_CACHE[model_name]

    enc = None
    # "Qwen" also covers the "Qwen3" and "Qwen/" forms.
    if AutoTokenizer is not None and "Qwen" in model_name:
        try:
            # Adjust this id/path if the tokenizer is available locally.
            enc = AutoTokenizer.from_pretrained(
                "Qwen/Qwen3-30B-A3B-Instruct-2507",
                trust_remote_code=True,
                use_fast=False,
            )
        except Exception as e:
            print(f"[WARN] Failed to load Qwen tokenizer: {e}. Falling back to tiktoken.")

    if enc is None and tiktoken is not None:
        try:
            enc = tiktoken.encoding_for_model(model_name)
        except Exception:
            # Generic safe fallback for model names tiktoken does not know.
            enc = tiktoken.get_encoding("cl100k_base")

    if enc is None:
        enc = CharacterEncoder()

    _ENCODER_CACHE[model_name] = enc
    return enc
def _truncate_to_tokens(text, enc, max_tokens) -> str:
"""
Truncates text to the last `max_tokens`.
Compatible with both tiktoken and Hugging Face AutoTokenizers.
"""
# Handle Hugging Face Tokenizers
if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase):
# add_special_tokens=False is crucial here to avoid double counting
# or inserting BOS/EOS in the middle of text during length checks
toks = enc.encode(text, add_special_tokens=False)
# Handle tiktoken
else:
toks = enc.encode(text, disallowed_special=())
if len(toks) <= max_tokens:
return text
# Keep the tail (usually the most relevant for instructions / recent context)
toks = toks[-max_tokens:]
return enc.decode(toks)
def truncate_chat_prompt(tokenizer, messages, max_context, max_output, overhead=256):
    """Render *messages* with the tokenizer's chat template and trim it to the input budget.

    The input budget is ``max_context - max_output - overhead`` tokens.  A longer
    rendered prompt keeps only its last *budget* tokens (the oldest content is
    dropped); a zero budget yields ``""``.  Returns the prompt text.

    Raises:
        ValueError: if ``max_output + overhead`` exceeds ``max_context``.
    """
    # Apply the model's chat template so token counting matches the server.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    budget = max_context - max_output - overhead
    if budget < 0:
        raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
    if budget == 0:
        # input_ids[-0:] would keep everything; with no budget nothing fits.
        return ""
    input_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids
    if len(input_ids) > budget:
        # Keeps the tail; use input_ids[:budget] instead to keep the start.
        prompt_text = tokenizer.decode(input_ids[-budget:], skip_special_tokens=False)
    return prompt_text
def llm_call(deployment_name: str,
             api_version: str,
             _prompt: str,
             max_context_tokens: int = MAX_CONTEXT_TOKENS,
             max_output_tokens: int = 1024,
             extra_overhead_tokens: int = 32,
             debug: bool = False,
             vllm: bool = False,
             tritonai: bool = False,
             nvidia: bool = False):
    """Send *_prompt* as a single chat message and return the raw completion object.

    Backend selection (the first flag that is set wins):
      * ``nvidia``   -> NVIDIA inference API (key from $NV_API_KEY)
      * ``tritonai`` -> LiteLLM proxy at TRITONAI_BASE_URL (key from $TRITONAI_API_KEY);
                        context forced to 131,072 tokens, output reservation >= 4096
      * ``vllm``     -> local vLLM server ($VLLM_BASE_URL / $VLLM_API_KEY); the model
                        name may be overridden by $VLLM_MODEL_NAME; context 131,072
      * ``debug``    -> OpenAI API ($OPENAI_API_KEY)
      * otherwise    -> Azure OpenAI at ``endpoint`` authenticated with ``credential``

    The prompt is truncated, keeping its tail, to
    ``max_context_tokens - max_output_tokens - extra_overhead_tokens`` tokens.
    ``max_output_tokens`` only shapes that budget; it is not sent to the API.

    Transient failures are retried indefinitely: timeouts after 30 s, and HTTP
    429/500/503/403 (plus 404 on the LiteLLM proxy) after 60 s.  Any other
    error is printed and re-raised.

    Raises:
        RuntimeError: if the Azure backend is selected but its SDK is not installed.
        ValueError: if the output reservation plus overhead exceeds the context window.
    """
    from openai import APITimeoutError  # openai is already a hard dependency (OpenAI)

    if nvidia:
        client = OpenAI(
            api_key=os.getenv("NV_API_KEY"),
            base_url="https://inference-api.nvidia.com/v1",
        )
    elif tritonai:
        client = OpenAI(
            api_key=os.getenv("TRITONAI_API_KEY"),
            base_url=TRITONAI_BASE_URL,
        )
        max_context_tokens = 131_072  # DeepSeek R1 128K context
        # DeepSeek R1 emits thinking tokens before the answer; reserve more output.
        max_output_tokens = max(max_output_tokens, 4096)
    elif vllm:
        # Local vLLM OpenAI-compatible server.
        client = OpenAI(
            base_url=os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1"),
            api_key=os.getenv("VLLM_API_KEY", "EMPTY"),
        )
        # Use the vLLM-served model name when set (needed when the main model goes
        # through the LiteLLM proxy but reading uses local vLLM).
        deployment_name = os.getenv("VLLM_MODEL_NAME", deployment_name)
        max_context_tokens = 131_072  # Qwen3-30B-A3B-Instruct-2507 context window
    elif debug:
        client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    else:
        if AzureOpenAI is None:
            raise RuntimeError(
                "Azure backend selected but openai.AzureOpenAI / azure.identity could not "
                "be imported; install them or use the debug/vllm/tritonai/nvidia backends."
            )
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )

    # Tokens left for the input once the output reservation and overhead are taken.
    budget = max_context_tokens - max_output_tokens - extra_overhead_tokens
    if budget < 0:
        raise ValueError("max_output_tokens + overhead exceeds max_context_tokens")
    prompt_truncated = _truncate_to_tokens(_prompt, _get_encoder(deployment_name), budget)

    if nvidia or tritonai:
        # Replace characters that cannot be UTF-8 encoded (e.g. lone surrogates) and
        # strip control characters other than \t, \n, \r; these break JSON serialization.
        prompt_truncated = prompt_truncated.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
        prompt_truncated = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt_truncated)
        json.dumps(prompt_truncated)  # sanity check: must be JSON-serializable

    # OpenAI-compatible proxies require at least one user message; Azure accepts system role.
    msg_role = "user" if (tritonai or nvidia) else "system"
    kwargs = {
        'model': deployment_name,
        'messages': [
            {"role": msg_role, "content": prompt_truncated}
        ],
    }

    # 404 from the LiteLLM proxy/Bedrock is intermittent (model temporarily unavailable).
    retryable = (429, 500, 503, 403) + ((404,) if tritonai else ())
    while True:
        try:
            return client.chat.completions.create(**kwargs)
        except APITimeoutError:
            print("[WARN] APITimeoutError from LLM; sleeping 30s then retrying...", flush=True)
            time.sleep(30)
        except Exception as e:
            st = _retryable_status(e)
            if st in retryable:
                # `di` is not defined in this function; presumably it is the question
                # index set at module scope by the driver loop.  Look it up defensively
                # so the retry path cannot die with NameError.
                q_idx = globals().get("di", "?")
                print(
                    f"[WARN] q_idx={q_idx} HTTP {st} from LLM; sleeping 60s then retrying...",
                    flush=True
                )
                time.sleep(60)
                continue
            # Non-retryable -> re-raise
            print('One exception captured', repr(e), flush=True)
            raise
def custom_to_iso8601(time_str):
    """
    Convert a 'YYYY/MM/DD (Weekday) HH:MM' timestamp to ISO 8601,
    e.g. '2023/04/10 (Mon) 23:07' -> '2023-04-10T23:07:00'.
    """
    # Drop the "(Mon)" marker: the date precedes '(' and the time follows ')'.
    date_part = time_str.split('(')[0].strip()
    clock_part = time_str.split(')')[-1].strip()
    parsed = datetime.strptime(f"{date_part} {clock_part}", "%Y/%m/%d %H:%M")
    return parsed.isoformat()
def evaluate_retrieval(recalled_docs, correct_docs, k=10):
    """Return ``(recall_any, recall_all)`` as 1.0 / 0.0 floats.

    recall_any is 1.0 when at least one of *correct_docs* is in *recalled_docs*;
    recall_all is 1.0 when every one is (vacuously 1.0 for empty *correct_docs*).
    *k* is unused and kept only for interface compatibility.
    """
    def _is_hit(doc):
        return doc in recalled_docs

    recall_any = 1.0 if any(map(_is_hit, correct_docs)) else 0.0
    recall_all = 1.0 if all(map(_is_hit, correct_docs)) else 0.0
    return recall_any, recall_all
def print_average_metrics(retrieval_metric_list):
    """Print the mean of each metric across a list of metric dicts, sorted by name.

    A metric missing from some dicts is averaged only over the dicts that have it.
    Values are printed with four decimals.
    """
    stats = {}  # metric name -> (running sum, count)
    for metrics in retrieval_metric_list:
        for name, value in metrics.items():
            total, n = stats.get(name, (0.0, 0))
            stats[name] = (total + value, n + 1)
    print("\t\t Average metrics:")
    for name, (total, n) in sorted(stats.items()):
        print(f"\t\t {name}: {total / n:.4f}")
# Planner instructions; RetrievalAgent._plan() puts this text in front of every
# planning prompt.
prompt_path = "prompts/agentic_retrieval_prompt.txt"
with open(prompt_path, mode="r", encoding="utf-8") as f:
    stg_prompt = f.read()
class ChatHistory:
    """Ordered collection of chat sessions with a flattened per-turn view.

    Attributes:
        sessions: list of session dicts.  Sessions built from a haystack dict
            have the keys ``session_date`` (raw date string), ``timestamp``
            (datetime), ``session_id``, ``session`` (list of turn dicts with
            ``role`` / ``content``) and ``topic``.
        messages: one dict per turn with ``role``, ``content``, ``session_id``,
            ``turn_index``, ``timestamp``, ``iso_datetime``, ``session_date`` and
            ``has_answer``.
        raw_data: the haystack dict, present only when constructed from *data*.
    """

    def __init__(self, data: Dict[str, Any] = None, sessions: List = None):
        """Build from a haystack dict (*data*), from session dicts (*sessions*), or empty.

        Passing both *data* and *sessions* is an error.
        """
        assert not (data is not None and sessions is not None), "ChatHistory: Only one of data or sessions may be provided."
        self.sessions = []
        self.messages = []
        if data is not None:  # From raw data dict
            self.raw_data = data
            self._parse_sessions()
        elif sessions is not None:  # From an existing list of session dicts
            self.sessions = sessions
            for sess in self.sessions:
                session_id = sess['session_id']
                timestamp = sess['timestamp']
                session_date = sess.get('session_date')
                for turn_idx, msg in enumerate(sess['session']):
                    self.messages.append(
                        self._make_message(msg, session_id, turn_idx, timestamp, session_date)
                    )

    @staticmethod
    def _make_message(msg, session_id, turn_idx, timestamp, session_date):
        """Flatten one turn into the per-message dict stored in ``messages``."""
        return {
            "role": msg.get("role"),
            "content": msg.get("content"),
            "session_id": session_id,
            "turn_index": turn_idx,
            "timestamp": timestamp,
            "iso_datetime": timestamp.isoformat(),
            "session_date": session_date,
            "has_answer": msg.get("has_answer", False),
        }

    @staticmethod
    def _json_default(obj):
        """``json.dumps`` hook: serialize datetimes (the ``timestamp`` fields) as ISO 8601."""
        if isinstance(obj, datetime):
            return obj.isoformat()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def get_contents(self, granularity='turn', _format='json') -> list:
        """Return the history as a list of strings, one per turn or per session.

        granularity='turn':
            ``_format='json'`` -> each message dict as JSON; otherwise the raw contents.
        any other granularity (session level):
            ``_format='json'`` -> each session dict as JSON; otherwise a JSON list of
            that session's ``{"role", "content"}`` turns.
        Datetime values are serialized as ISO 8601 strings.
        """
        if granularity == "turn":
            if _format == "json":
                return [json.dumps(msg, default=self._json_default) for msg in self.messages]
            return [msg['content'] for msg in self.messages]
        if _format == "json":
            return [json.dumps(session, default=self._json_default) for session in self.sessions]
        return [
            json.dumps([{"role": turn["role"], "content": turn["content"]} for turn in session["session"]])
            for session in self.sessions
        ]

    def to_prompt(self, granularity='session', _format="json"):
        """Render every session as a date header followed by its turns as JSON.

        *granularity* and *_format* are currently ignored.
        """
        history_str = ""
        for session in self.sessions:
            sess_str = json.dumps([{"role": x["role"], "content": x["content"]} for x in session['session']])
            history_str += f"Session Date: {session['session_date']}\nSession Content:\n{sess_str}\n"
        return history_str

    def get_session_ids(self):
        """Return the session ids in session order."""
        return [s['session_id'] for s in self.sessions]

    @staticmethod
    def _parse_date(date_str: str) -> datetime:
        """Parse '2023/04/10 (Mon) 17:50' into a naive datetime (the weekday is ignored)."""
        date_part, time_part = date_str.split('(')[0].strip(), date_str.split(')')[-1].strip()
        return datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M")

    def _parse_sessions(self):
        """Populate ``sessions`` and ``messages`` from ``self.raw_data``.

        The parallel lists ``haystack_dates``, ``haystack_session_ids``,
        ``haystack_sessions`` and ``haystack_topics`` are zipped, so the shortest one
        bounds the result.  Input order is preserved (no re-sorting by time).
        """
        self.messages = []
        for date_str, session_id, session, topic in zip(
            self.raw_data['haystack_dates'],
            self.raw_data['haystack_session_ids'],
            self.raw_data['haystack_sessions'],
            self.raw_data['haystack_topics'],
        ):
            timestamp = self._parse_date(date_str)
            for turn_idx, msg in enumerate(session):
                self.messages.append(self._make_message(msg, session_id, turn_idx, timestamp, date_str))
            self.sessions.append({
                "session_date": date_str,
                "timestamp": timestamp,
                "session_id": session_id,
                "session": session,
                "topic": topic,
            })

    def __len__(self):
        return len(self.sessions)

    def __getitem__(self, idx) -> Dict[str, Any]:
        return self.sessions[idx]  # Return session dict

    def get_item_by_index(self, idx):
        """Return a ChatHistory of the sessions at the given positions.

        Out-of-range indices are skipped; the given order (and repeats) are kept.

        Raises:
            ValueError: if *idx* is not a list or range.
        """
        if isinstance(idx, (range, list)):
            max_idx = len(self.sessions)
            return ChatHistory(sessions=[self.sessions[i] for i in idx if 0 <= i < max_idx])
        raise ValueError("Input must be a list or range of indices.")

    def get_item_by_session_ids(self, sess_set):
        """Return a ChatHistory of the sessions whose id is in *sess_set*, in history order."""
        if not isinstance(sess_set, set):
            sess_set = set(sess_set)
        return ChatHistory(sessions=[sess for sess in self.sessions if sess['session_id'] in sess_set])

    def get_item_by_ranked_session(self, sess_id_sorted):
        """Return a ChatHistory whose sessions follow the order of *sess_id_sorted*.

        Ids are matched exactly (a substring test would let "s1" match "s10").
        Unknown ids are skipped; an id listed several times yields its session
        that many times.
        """
        by_id = {}
        for sess in self.sessions:
            by_id.setdefault(sess['session_id'], []).append(sess)
        new_sessions = []
        for sess_id in sess_id_sorted:
            new_sessions.extend(by_id.get(sess_id, []))
        return ChatHistory(sessions=new_sessions)

    def get_item_by_topics(self, topics):
        """Return a ChatHistory of the sessions with at least one topic in *topics*."""
        new_sessions = []
        new_sess_ids = set()
        for sess in self.sessions:
            for tp in sess['topic']:
                if tp in topics and sess['session_id'] not in new_sess_ids:
                    new_sessions.append(sess)
                    new_sess_ids.add(sess['session_id'])
                    break
        return ChatHistory(sessions=new_sessions)

    def merge_rel_sess(self, new_sessions):
        """Merge the sessions of another ChatHistory into this one, in place.

        Sessions whose id is already present are kept as they are; new ones are
        appended in the order they appear in *new_sessions*.  Both ``sessions`` (a
        list) and ``messages`` are rebuilt.  Returns None.
        """
        merged = {s["session_id"]: s for s in self.sessions}
        for s in new_sessions.sessions:
            merged.setdefault(s["session_id"], s)
        rebuilt = ChatHistory(sessions=list(merged.values()))
        self.sessions = rebuilt.sessions
        self.messages = rebuilt.messages
def generate_keywords(question: str, deployment_name, api_version, debug=False, vllm=None,
                      tritonai=False, nvidia=False) -> List[str]:
    """Ask the LLM for search keywords for *question*.

    The prompt is ``prompts/keyword_search_prompt.txt`` followed directly by the
    question.  Returns the ``"keywords"`` field of the parsed JSON reply, or
    ``[]`` when that field is absent (including when the reply is not JSON).
    """
    with open('prompts/keyword_search_prompt.txt') as template_file:
        prompt = template_file.read() + question

    completion = llm_call(
        deployment_name,
        api_version,
        prompt,
        debug=debug,
        vllm=vllm,
        tritonai=tritonai,
        nvidia=nvidia,
    )
    raw_reply = completion.choices[0].message.content or ""
    parsed = parse_json(raw_reply.strip())
    if "keywords" not in parsed:
        return []
    return parsed["keywords"]
def keyword_search(chat_history: ChatHistory, keywords: list):
    """Return the sessions of *chat_history* that mention any of *keywords*.

    Matching is a case-insensitive substring test on each message's ``content``;
    missing or None content counts as empty.  A session is kept when any of its
    messages matches, in *chat_history*'s session order.  Returns an empty
    ChatHistory when nothing matches.
    """
    print(f"\t\t** keyword search **: {keywords}")
    # Lower-case each keyword once instead of once per message.
    lowered = [kw.lower() for kw in keywords]
    matched_sess_ids = set()
    for msg in chat_history.messages:
        text = (msg.get("content") or "").lower()
        if any(kw in text for kw in lowered):
            matched_sess_ids.add(msg["session_id"])
    if not matched_sess_ids:
        return ChatHistory()
    return chat_history.get_item_by_session_ids(matched_sess_ids)
def is_turn_id(text):
    """Return True when *text* ends with an underscore followed by digits (e.g. 'sess_12')."""
    return re.search(r'_\d+$', text) is not None
def _corpus_id_to_session_id(corpus_id: str) -> str:
    """Map a retrieval-cache corpus id back to the id of the session it came from.

    Ids already in ``valid_sess_set`` are returned unchanged.  Otherwise the first
    matching rule applies: strip a ``_turn...`` or ``_fact...`` suffix, rewrite
    ``noans`` -> ``answer``, or drop a trailing ``_<digits>`` turn index.  Anything
    else is returned as-is.
    """
    if corpus_id in valid_sess_set:
        return corpus_id
    if "_turn" in corpus_id:
        return corpus_id.split("_turn")[0]
    if "_fact" in corpus_id:
        return corpus_id.split("_fact")[0]
    if "noans" in corpus_id:
        return corpus_id.replace("noans", "answer")
    if is_turn_id(corpus_id):
        return "_".join(corpus_id.split("_")[:-1])  # remove turn index
    return corpus_id


def _ranked_session_ids(qid: str, allowed: set, top_k: int, validate: bool) -> List[str]:
    """Return up to *top_k* distinct session ids from the cached ranking for *qid*.

    Only ids in *allowed* are kept, in ranking order.  Several turns or facts of
    one session count once, so they cannot crowd out other sessions.  With
    *validate*, each mapped id must be a known session, otherwise the ids are
    printed and AssertionError is raised.  That check is explicit, so it
    survives ``python -O``.
    """
    if top_k <= 0:
        return []
    ids: List[str] = []
    seen = set()
    for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
        cid = item["corpus_id"]
        sid = _corpus_id_to_session_id(cid)
        if validate and sid not in valid_sess_set:
            print(cid, sid)
            raise AssertionError(f"unknown session id {sid!r} (from corpus id {cid!r})")
        if sid in allowed and sid not in seen:
            seen.add(sid)
            ids.append(sid)
            if len(ids) == top_k:
                break
    return ids


def embedding_search(chat_history: ChatHistory, qid: str, top_k: int = 50, exclude_sess=None):
    """Return up to *top_k* sessions of *chat_history*, ranked by the cached retrieval for *qid*.

    Sessions listed in *exclude_sess* are skipped.  Reads the global
    ``retrieved_data_dict`` cache (populated from $ret_cache).
    """
    print("\t\t** embedding based retrieval **")
    candidates = set(chat_history.get_session_ids())
    if exclude_sess:
        candidates -= set(exclude_sess)
    ranked = _ranked_session_ids(qid, candidates, top_k, validate=True)
    return chat_history.get_item_by_ranked_session(ranked)


def filter_out_by_embedding(chat_history: ChatHistory, qid: str, top_k: int = 50):
    """Keep the top *top_k* sessions of *chat_history* by the cached retrieval ranking for *qid*.

    Returns ``(ranked ChatHistory, 0.0)``.  The second element is always 0.0
    (presumably an elapsed-time slot kept for interface compatibility).
    """
    print("\t\t** [filter_out] embedding based retrieval - loading existing results ...")
    ranked = _ranked_session_ids(qid, set(chat_history.get_session_ids()), top_k, validate=True)
    return chat_history.get_item_by_ranked_session(ranked), 0.0


def flat_embedding_top_k_ids(qid: str, haystack_sess_ids: List[str], top_k: int) -> List[str]:
    """
    Pull the top_k session IDs from the global GTE retrieval cache (retrieved_data_dict),
    constrained to the question's haystack.  Same ranking as embedding_search() but
    returns IDs only (no ChatHistory) and skips the known-session check.  Used by
    hier_union to widen the Stage-2 pool.
    """
    return _ranked_session_ids(qid, set(haystack_sess_ids), top_k, validate=False)
def semantic_embedding_search(
    qid: str,
    haystack_sess_ids: List[str],
    semantic_retrieved_dict: dict,
    top_k: int = 50,
) -> List[str]:
    """Return up to *top_k* distinct haystack session ids ranked by the semantic cache.

    Reads ``semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]``.
    Its ``corpus_id`` values are already session-level ids.  Ids not in
    *haystack_sess_ids* and repeats are skipped, and ranking order is kept.
    """
    print("\t\t** semantic embedding retrieval **")
    allowed = set(haystack_sess_ids)
    picked: List[str] = []
    already = set()
    for entry in semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]:
        sess_id = entry["corpus_id"]
        if sess_id not in allowed or sess_id in already:
            continue
        already.add(sess_id)
        picked.append(sess_id)
        if len(picked) == top_k:
            break
    return picked
def time_filter(chat_history: ChatHistory, start_date: str, end_date: str) -> ChatHistory:
    """Return the sessions with at least one message dated within [start_date, end_date].

    Both bounds are ISO 8601 strings, compared at day granularity (inclusive).
    If either bound cannot be parsed, the error is printed and an empty
    ChatHistory is returned, as it is when no message falls in the range.
    """
    try:
        start_day = datetime.fromisoformat(start_date).date()
        end_day = datetime.fromisoformat(end_date).date()
    except (TypeError, ValueError) as e:  # unparsable string, or a non-string (e.g. None)
        print("Converting date error: ", e)
        return ChatHistory()
    in_range = {
        msg["session_id"]
        for msg in chat_history.messages
        if start_day <= msg["timestamp"].date() <= end_day
    }
    if not in_range:
        return ChatHistory()
    return chat_history.get_item_by_session_ids(in_range)
class RetrievalAgent:
def __init__(
    self,
    history: List[Dict],
    topics: List[str],
    user_profile: str = None,
    debug: bool = False,
    vllm: bool = False,
    vllm_reading: bool = False,
    tritonai: bool = False,
    nvidia: bool = False,
    n_chunks: int = 10,
    topic_filter: bool = True,
    no_time_filter: bool = False,
    semantic_store: SemanticMemoryStore = None,
    episodic_store: EpisodicMemoryStore = None,
    hier_v2: bool = False,
    hier_union: bool = False,
    hier_union_flat_k: int = 20,
    no_early_answer: bool = False,
):
    """Store the agent configuration, reset per-component token counters and load the reading prompt.

    Backend flags: ``debug`` / ``vllm`` / ``tritonai`` / ``nvidia`` select the
    llm_call() backend.  ``vllm_reading`` routes only _read_and_verify() to
    vLLM.  ``tritonai`` is used for the non-reading calls.  ``no_time_filter``
    skips time-filter steps.  ``hier_v2`` makes Stage 1 candidate-only.
    ``hier_union`` / ``hier_union_flat_k`` presumably control widening the
    Stage-2 pool with flat embedding hits.  ``topic_filter``, ``n_chunks`` and
    ``no_early_answer`` are consumed outside this view (TODO confirm semantics).
    """
    # Memory being searched.
    self.chat_history = history
    self.user_profile = user_profile
    self.topics = topics
    # Retrieval state accumulated across loops.
    self.rel_sess = ChatHistory()
    self.evidence = []
    # LLM backend switches.
    self.debug = debug
    self.vllm = vllm
    self.vllm_reading = vllm_reading  # use vLLM only for _read_and_verify
    self.tritonai = tritonai  # use LiteLLM proxy for non-reading LLM calls
    self.nvidia = nvidia  # use NVIDIA inference API
    # Strategy options.
    self.no_time_filter = no_time_filter  # skip time_filter steps in strategy
    self.n_chunks = n_chunks
    self.topic_filter = topic_filter
    self.semantic_store = semantic_store
    self.episodic_store = episodic_store
    self.hier_v2 = hier_v2
    self.hier_union = hier_union
    self.hier_union_flat_k = hier_union_flat_k
    self.no_early_answer = no_early_answer
    # One fresh counter dict per LLM-using component.
    self.token_budget = {
        component: {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}
        for component in ('planning', 'verification_reading', 'is_answerable', 'final_answer')
    }
    # NOTE(review): the wo_profile ablation loads the *agentic_retrieval* prompt as the
    # reading template; confirm that is the intended file.
    read_prompt_path = (
        'prompts/read_and_extract_prompt.txt'
        if self.user_profile
        else 'prompts/agentic_retrieval_prompt_wo_profile.txt'  # ablation: wo_profile
    )
    with open(read_prompt_path) as prompt_file:
        self.read_prompt_template = prompt_file.read()
def _track_usage(self, component: str, completion) -> None:
"""Accumulate prompt/completion token counts for a named component."""
usage = getattr(completion, 'usage', None)
if usage is None:
return
self.token_budget[component]['prompt_tokens'] += getattr(usage, 'prompt_tokens', 0) or 0
self.token_budget[component]['completion_tokens'] += getattr(usage, 'completion_tokens', 0) or 0
self.token_budget[component]['n_calls'] += 1
def get_token_budget(self) -> dict:
"""Return token_budget with an added 'total' entry."""
total = {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}
for v in self.token_budget.values():
total['prompt_tokens'] += v['prompt_tokens']
total['completion_tokens'] += v['completion_tokens']
total['n_calls'] += v['n_calls']
return {**self.token_budget, 'total': total}
def is_answerable(self, question: str, question_date: str, retrieved_sess, evidence, model_info, context_str: str = None) -> dict:
    """Ask the LLM whether *question* can be answered from the retrieved memory.

    The prompt holds the question, its date, the user profile (when set) and the
    memory context: *context_str* when given, else ``retrieved_sess.to_prompt()``.
    If the prompt exceeds the model's input budget, the memory context is
    trimmed from the front, dropping the oldest content first.

    Args:
        question, question_date: the user query and its date string.
        retrieved_sess: ChatHistory rendered when *context_str* is None.
        evidence: evidence mapping.  NOTE(review): it is accepted but not included
            in the prompt, although the output spec mentions "the Evidence";
            confirm whether it should be.
        model_info: ``(deployment_name, api_version)``.
        context_str: optional pre-rendered context (used by Stage 1).

    Returns:
        The parsed JSON reply, e.g. ``{"is_answerable": true, "answer": ...}`` or
        ``{"is_answerable": false, "info_needed": [...]}``.  An unparsable reply
        gives ``{"is_answerable": False, "info_needed": ["Parsing failed"]}``.
    """
    # context_str overrides retrieved_sess.to_prompt() (used in Stage 1 with semantic context)
    sess_str = context_str if context_str is not None else retrieved_sess.to_prompt()
    # Include user profile in the prompt when available
    profile_section = ""
    if self.user_profile:
        profile_section = f"\nUser Profile:\n{self.user_profile}\n"
    ia_prompt_prefix = f"""
You are a decision-making agent tasked with determining when sufficient information has been gathered to answer a user's question.
Your Task:
Analyze the provided question, current date, available memory context, and available evidence to make a binary decision: Answerable or Not answerable. If the information is not sufficient, explain what specific information is needed to provide hints for the next retrieval stage.
Question: {question}
Current Date: {question_date}
{profile_section}
Memory Context:
"""
    output_str = """
Output (always JSON — choose fields per the rules above)
Case 1 — Answerable:
{
"is_answerable": true,
"answer": "<concise answer grounded strictly in the Evidence>"
}
Case 2 — Not answerable:
{
"is_answerable": false,
"info_needed": ["<specific missing detail 1>", "<specific missing detail 2>"]
}
"""
    deployment_name, api_version = model_info
    enc = _get_encoder(deployment_name)

    # Budget with the same limits llm_call() applies for this backend.  Otherwise
    # llm_call() re-truncates the tail and cuts the question off the front.
    model_max_ctx = 131_072 if (self.vllm or self.tritonai) else MAX_CONTEXT_TOKENS
    # llm_call() raises the output reservation to 4096 on the LiteLLM proxy.
    max_output_tokens = 4096 if self.tritonai else 1024
    extra_overhead_tokens = 32
    budget = model_max_ctx - max_output_tokens - extra_overhead_tokens
    if budget <= 0:
        raise ValueError(
            f"max_output_tokens ({max_output_tokens}) + overhead "
            f"({extra_overhead_tokens}) exceeds model_max_ctx ({model_max_ctx})"
        )

    def _n_tokens(text):
        # Count tokens as _truncate_to_tokens() does: no BOS/EOS for HF tokenizers,
        # and special-token strings treated as plain text for tiktoken.
        if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase):
            return len(enc.encode(text, add_special_tokens=False))
        return len(enc.encode(text, disallowed_special=()))

    # Everything except the memory context is fixed; "\n" separates context and spec.
    available_for_sess = budget - _n_tokens(ia_prompt_prefix) - _n_tokens("\n" + output_str)
    if available_for_sess <= 0:
        truncated_sess_str = ""  # no room for history at all; drop it
    else:
        # Keeps the *last* tokens (drops the oldest history); unchanged if it fits.
        truncated_sess_str = _truncate_to_tokens(sess_str, enc, available_for_sess)
    ia_prompt = ia_prompt_prefix + truncated_sess_str + "\n"

    completion = llm_call(
        deployment_name,
        api_version,
        ia_prompt + output_str,
        max_context_tokens=model_max_ctx,  # matches what we used for budgeting
        max_output_tokens=max_output_tokens,
        extra_overhead_tokens=extra_overhead_tokens,
        debug=self.debug,
        vllm=self.vllm,
        tritonai=self.tritonai,
        nvidia=self.nvidia,
    )
    self._track_usage('is_answerable', completion)
    response_content = (completion.choices[0].message.content or "").strip()
    print(f"\t\t[Agent] is_answerable: {response_content}")
    result = parse_json(response_content)
    if not result:
        print("[Warning] Empty or invalid JSON in is_answerable() response.")
        return {"is_answerable": False, "info_needed": ["Parsing failed"]}
    return result
def _read_and_verify(self, question: str, question_date: str, evidence: ChatHistory, n_chunks=10) -> tuple:
    """Have the LLM read *evidence* in chunks and keep only the sessions it needs.

    Sessions are sent *n_chunks* at a time, each labelled with its index inside
    the chunk and its date.  The reply's last JSON object should carry an
    ``"index"`` list of chunk-local indices and optionally an ``"evidence"`` list
    (collected only when ``"index"`` is non-empty).  Non-integer, out-of-chunk
    and repeated indices are ignored.

    Returns:
        ``(ChatHistory of the selected sessions in selection order, evidence list)``,
        or ``(ChatHistory(), [])`` when nothing is selected.

    NOTE(review): ``deployment_name`` / ``api_version`` are neither parameters nor
    attributes here; presumably they are module-level globals set by the driver
    script.  Reading calls do not pass ``tritonai`` (non-reading calls only).
    """
    selected = []  # global indices into `evidence`, in selection order, no repeats
    seen = set()
    evidence_list = []
    max_idx = len(evidence)
    for j in range(0, max_idx, n_chunks):
        valid_indices = [i for i in range(j, j + n_chunks) if 0 <= i < max_idx]
        cur_chunk = evidence.get_item_by_index(valid_indices)
        cur_chunk_sess = [[{"role": m["role"], "content": m["content"]} for m in sess['session']]
                          for sess in cur_chunk.sessions]
        cur_chunk_sess_date = [sess['session_date'] for sess in cur_chunk.sessions]
        sess_input_str = "\n".join([
            f"### Session Index: {i}\n### Session Date: {sess_date}\n\n{json.dumps(sess)}\n"
            for i, (sess, sess_date) in enumerate(zip(cur_chunk_sess, cur_chunk_sess_date))
        ])
        _prompt = self.read_prompt_template + f"## Question: {question}\n## Question Date: {question_date}\n## Session list:\n\n{sess_input_str}\nNow, identify **only the sessions strictly necessary to answer the question**."
        completion = llm_call(
            deployment_name,
            api_version,
            _prompt,
            debug=self.debug,
            vllm=self.vllm or self.vllm_reading,
            nvidia=self.nvidia,
        )
        self._track_usage('verification_reading', completion)
        response_content = (completion.choices[0].message.content or "").strip()
        print(f"\t\t {valid_indices[0]}~{valid_indices[-1]}: response: {response_content.replace(chr(10), '')}")

        try:
            # The verdict is expected to be the last (flat) JSON object in the reply.
            start_idx = response_content.rfind('{')
            end_idx = response_content.rfind('}') + 1
            result = json.loads(response_content[start_idx:end_idx])
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            print(f"Error parsing LLM response: {e}")
            continue
        if not isinstance(result, dict) or not result.get("index"):
            continue

        chunk_size = len(valid_indices)
        for idx in result["index"]:
            # Skip non-integers (bool is an int subclass) and indices outside this
            # chunk, which would otherwise select sessions from another chunk.
            if isinstance(idx, bool) or not isinstance(idx, int) or not 0 <= idx < chunk_size:
                continue
            if j + idx not in seen:
                seen.add(j + idx)
                selected.append(j + idx)
        chunk_evidence = result.get("evidence")
        if isinstance(chunk_evidence, list):
            evidence_list.extend(chunk_evidence)
        elif chunk_evidence:
            evidence_list.append(chunk_evidence)

    if selected:
        return evidence.get_item_by_index(selected), evidence_list
    return ChatHistory(), []
def _read_and_verify_with_cache(self, qid:str, pool):
    """Replay cached reading decisions for *qid* against *pool*.

    Keeps the sessions of *pool* whose ids appear in ``qid2rel_sess_ids[qid]``.
    Returns ``(ChatHistory of those sessions, [])``, or ``(ChatHistory(), [])``
    when none match; the evidence list is always empty.  Raises KeyError if *qid*
    is missing from the cache and *pool* is non-empty.
    """
    hits = [
        sess['session_id']
        for sess in pool.sessions
        if sess['session_id'] in qid2rel_sess_ids[qid]
    ]
    if not hits:
        return ChatHistory(), []
    return pool.get_item_by_session_ids(hits), []
def _plan(self, query: str, query_date: str, attempt_record: list, model_info) -> dict:
    """Ask the LLM for a retrieval plan (or a direct answer) for *query*.

    The prompt is ``stg_prompt`` followed by the user profile (when set), the
    chat-history topics, the query and its date, and a summary of the previous
    attempts in *attempt_record*.  A reply that does not parse to a non-empty
    JSON object is retried once.  If it still fails, a fallback plan
    (``answer="none"``, empty strategy) is returned.

    Returns:
        The plan dict (keys used downstream include ``answer`` and ``strategy``).
    """
    if attempt_record:
        parts = []
        for loop_num, entry in enumerate(attempt_record, start=1):
            parts.append(f"\nloop_iteration: {loop_num}\n")
            parts.append("\n".join(entry.get('step_logs', [])))
            if 'n_retrieved_sess' in entry and 'evidence' in entry:
                # Surrounding newlines keep this line from being glued onto the step
                # logs above or onto the next line of the summary.
                parts.append(f"\nRetrieved {entry['n_retrieved_sess']} docs, observed_evidence: {entry['evidence']}\n")
                if entry['n_retrieved_sess'] == 0:
                    parts.append("Additional Instruction: Re-try without filter methods if the previous plan includes topics or time-filtering\n")
        strategies_info = "".join(parts)
    else:
        strategies_info = "(No previous attempt exists)"

    profile_line = f"### User profile: {self.user_profile}\n" if self.user_profile else ""
    prompt_filled = (
        stg_prompt
        + "\n"
        + profile_line
        + f"### Chat history topics: {','.join(self.topics)}\n"
        + f"### User query: {query}\n"
        + f"### User query date: {query_date}\n"
        + f"### Previous attempts:\n{strategies_info}\n"
    )

    deployment_name, api_version = model_info
    for attempt in range(2):
        if attempt:
            print("[Warning] Failed to parse plan JSON — retrying once.")
        completion = llm_call(
            deployment_name,
            api_version,
            prompt_filled,
            debug=self.debug,
            vllm=self.vllm,
            tritonai=self.tritonai,
            nvidia=self.nvidia,
        )
        self._track_usage('planning', completion)
        response_content = (completion.choices[0].message.content or "").strip()
        plan = parse_json(response_content)
        # Callers use dict access (plan.get / plan["answer"]), so only a non-empty
        # dict counts as a valid plan.
        if isinstance(plan, dict) and plan:
            return plan
    print("[Warning] Failed to parse plan JSON after retry — returning fallback plan.")
    return {"answer": "none", "reason": "invalid JSON response", "topics": [], "strategy": []}
def _run_stage1(
    self,
    qid: str,
    question: str,
    question_date: str,
    top_k: int,
    model_info,
    haystack_sess_ids: List[str],
    date_lookup: Dict[str, str],
    semantic_ret_dict: dict,
) -> dict:
    """
    Stage 1: retrieve and evaluate using semantic memory only (summaries + facts).

    Steps:
      1a. embedding search over semantic memory (``semantic_embedding_search``);
          with ``self.hier_v2`` stop here and return those candidates only.
      1b. ask the planner (``self._plan``) for a strategy; return at once if
          the plan already carries a real answer.
      1c. add sessions matched by the plan's ``keyword`` steps.
      1d. apply the plan's ``time_filter`` steps (unless ``self.no_time_filter``).
      1e. merge embedding + keyword candidates (embedding rank first), cap at top_k.
      1f. render the semantic context for the merged candidates.
      1g. ask ``self.is_answerable`` on that context.

    Returns a dict with:
        is_answerable  : bool
        answer         : str | None (set when is_answerable is True)
        candidate_ids  : list[str] (top-K session IDs from semantic retrieval)
        attempt_record : list
    """
    from datetime import datetime

    print(f"\t[Stage 1] Semantic memory retrieval for qid={qid}")

    # --- 1a. Semantic embedding search ---
    candidate_ids = semantic_embedding_search(
        qid, haystack_sess_ids, semantic_ret_dict, top_k=top_k
    )

    # hier_v2: skip plan / keyword / time_filter / is_answerable. Stage 1 is candidate-only.
    if self.hier_v2:
        print(f"\t[Stage 1 hier_v2] embedding-only candidates: {len(candidate_ids)}")
        return {
            "is_answerable": False,
            "answer": None,
            "candidate_ids": candidate_ids[:top_k],
            "attempt_record": [{
                "stage": "semantic_v2",
                "plan": {},
                "n_candidates": len(candidate_ids),
                "candidate_ids": candidate_ids[:top_k],
            }],
        }

    # --- 1b. Plan for keywords / time filter (reuse existing _plan) ---
    plan = self._plan(question, question_date, [], model_info)
    print(json.dumps(plan, indent=4), flush=True)

    # If the planner already has a direct answer, return it.
    # Use the same "contains 'none'" test as run() and the __main__ answer
    # check: with an exact-equality test, a reply such as "None mentioned"
    # returned here (skipping Stage 2) but was then rejected downstream,
    # leaving the final answer to be produced from zero retrieved sessions.
    # str() guards against non-string JSON answers (numbers / null).
    direct_answer = plan.get("answer")
    if direct_answer is not None and "none" not in str(direct_answer).lower():
        return {
            "is_answerable": True,
            "answer": direct_answer,
            "candidate_ids": candidate_ids,
            "attempt_record": [{"plan": plan, "stage": "semantic"}],
        }

    strategy = plan.get("strategy") or []

    # --- 1c. Semantic keyword search (de-duplicated, first-seen order) ---
    keyword_ids: List[str] = []
    seen_keyword_ids = set()
    for step in strategy:
        if step.get("method") != "keyword":
            continue
        kws = step.get("keywords", [])
        matched = self.semantic_store.keyword_search(kws, haystack_sess_ids)
        print(f"\t\t** semantic keyword search **: {kws} -> {len(matched)} matches")
        for sid in matched:
            if sid not in seen_keyword_ids:
                seen_keyword_ids.add(sid)
                keyword_ids.append(sid)

    # --- 1d. Time filter on candidate set ---
    def _within(sid, start_day, end_day):
        """True if session *sid* has a known date inside [start_day, end_day] (inclusive)."""
        return (sid in date_lookup and
                start_day <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_day)

    if not self.no_time_filter:
        for step in strategy:
            if step.get("method") != "time_filter":
                continue
            time_range = step.get("time_range")
            if not isinstance(time_range, (list, tuple)) or len(time_range) != 2:
                continue
            start_str, end_str = time_range
            try:
                start_day = datetime.fromisoformat(start_str).date()
                end_day = datetime.fromisoformat(end_str).date()
                kept_candidates = [sid for sid in candidate_ids if _within(sid, start_day, end_day)]
                kept_keywords = [sid for sid in keyword_ids if _within(sid, start_day, end_day)]
            except Exception as e:
                print(f"\t\t[WARN] time_filter parse error: {e}")
                continue
            # Commit both lists together so an error part-way through cannot
            # leave one list filtered and the other untouched.
            candidate_ids, keyword_ids = kept_candidates, kept_keywords
            print(f"\t\t** semantic time_filter **: {start_str}..{end_str} -> "
                  f"{len(candidate_ids)} embed, {len(keyword_ids)} keyword")

    # --- 1e. Merge candidates (keyword union with embedding, preserve rank) ---
    all_candidate_ids: List[str] = list(candidate_ids)
    seen_candidate_ids = set(all_candidate_ids)
    for sid in keyword_ids:
        if sid not in seen_candidate_ids:
            seen_candidate_ids.add(sid)
            all_candidate_ids.append(sid)
    # Cap to top_k
    all_candidate_ids = all_candidate_ids[:top_k]

    if not all_candidate_ids:
        print("\t[Stage 1] No candidates found in semantic memory.")
        return {
            "is_answerable": False,
            "answer": None,
            "candidate_ids": [],
            "attempt_record": [{"plan": plan, "stage": "semantic", "n_candidates": 0}],
        }

    # --- 1f. Build semantic context string ---
    semantic_context_str = self.semantic_store.to_prompt(all_candidate_ids, date_lookup)
    print(f"\t[Stage 1] Built semantic context for {len(all_candidate_ids)} sessions "
          f"({len(semantic_context_str)} chars)")

    # --- 1g. is_answerable check on semantic context ---
    accumulated_evidence = {"profile": [], "chat_clues": []}
    answerable_response = self.is_answerable(
        question, question_date,
        retrieved_sess=None,
        evidence=accumulated_evidence,
        model_info=model_info,
        context_str=semantic_context_str,
    )
    print(f"\t[Stage 1] is_answerable: {answerable_response}")

    # The record gets its own copy of the id list: callers may extend the
    # returned candidate_ids (e.g. hier_union), which must not rewrite the log.
    attempt_record = [{
        "stage": "semantic",
        "plan": plan,
        "n_candidates": len(all_candidate_ids),
        "candidate_ids": list(all_candidate_ids),
        "is_answerable": answerable_response.get("is_answerable", False),
    }]

    if answerable_response.get("is_answerable"):
        return {
            "is_answerable": True,
            "answer": answerable_response.get("answer"),
            "candidate_ids": all_candidate_ids,
            "attempt_record": attempt_record,
        }

    print(f"\t[Stage 1] Not answerable from semantic memory. "
          f"Info needed: {answerable_response.get('info_needed', [])}")
    return {
        "is_answerable": False,
        "answer": None,
        "candidate_ids": all_candidate_ids,
        "attempt_record": attempt_record,
    }
def run(self, qid:str, question: str, question_date: str, top_k: int, model_info, max_loops=3,
        semantic_ret_dict: dict = None, haystack_sess_ids: List[str] = None,
        date_lookup: Dict[str, str] = None, topic_lookup: Dict[str, List[str]] = None):
    """Retrieve the chat sessions needed to answer *question*.

    Pipeline:
      * Stage 1 runs only when ``self.semantic_store``, ``semantic_ret_dict`` and
        ``haystack_sess_ids`` are all provided. It selects candidates from
        semantic memory via ``_run_stage1`` and may return early with an answer
        (unless ``self.no_early_answer``).
      * ``hier_union`` (only when ``qid`` is in the module-level
        ``retrieved_data_dict``) appends flat-embedding top-K ids to the
        Stage-1 candidates.
      * Stage 2: when ``self.episodic_store`` is set and there are candidates,
        ``self.chat_history`` is replaced by the candidates' raw sessions; with
        ``self.hier_v2`` those sessions are returned directly.
      * Agent loop (up to ``max_loops``): plan -> execute strategy -> verify by
        reading -> accumulate evidence -> ask ``is_answerable``.

    Returns:
        (retrieved, attempt_record). ``retrieved`` is a ChatHistory (empty when
        an answer was produced without retrieval). ``attempt_record`` is a list
        of per-stage / per-loop log dicts. A Stage-1 or planner answer is
        reported via ``attempt_record[0]["plan"]["answer"]``; an
        ``is_answerable`` answer from the loop is written onto that iteration's
        plan.
    """

    def _has_answer(p: dict) -> bool:
        # "Contains 'none'" is the placeholder convention used by the caller's
        # answer check too. str() guards against non-string JSON answers (a
        # bare number or null), which previously raised on ``.lower()``.
        ans = p.get("answer")
        return ans is not None and "none" not in str(ans).lower()

    accumulated_evidence = {"profile": [], "chat_clues": []}
    attempt_record = []
    loop_num = 0

    # ----------------------------------------------------------------
    # Stage 1: Semantic memory — only runs when stores are provided
    # ----------------------------------------------------------------
    stage1_candidate_ids: List[str] = []
    if (self.semantic_store is not None
            and semantic_ret_dict is not None
            and haystack_sess_ids is not None):
        stage1_result = self._run_stage1(
            qid, question, question_date, top_k, model_info,
            haystack_sess_ids, date_lookup or {}, semantic_ret_dict,
        )
        attempt_record.extend(stage1_result["attempt_record"])
        # Copy: hier_union below extends this list in place, and the list
        # returned by _run_stage1 may be the one stored in its attempt record.
        stage1_candidate_ids = list(stage1_result["candidate_ids"])
        if stage1_result["is_answerable"] and not self.no_early_answer:
            print(f"\t[Stage 1] Answered from semantic memory.")
            # Wrap answer in attempt_record format expected by caller
            if "answer" in stage1_result and stage1_result["answer"]:
                attempt_record[0]["plan"] = {
                    **attempt_record[0].get("plan", {}),
                    "answer": stage1_result["answer"],
                }
            return ChatHistory(), attempt_record
        if stage1_result["is_answerable"] and self.no_early_answer:
            print(f"\t[Stage 1] is_answerable=True but --no_early_answer set; proceeding to Stage-2.")

        # hier_union: widen Stage-2 pool with flat-embedding top-K from the global GTE cache.
        # This makes hier a strict superset of flat by construction; targets the recall gap
        # (semantic-over-summary embeddings rank worse than full-session embeddings).
        if self.hier_union and qid in retrieved_data_dict:
            flat_ids = flat_embedding_top_k_ids(qid, haystack_sess_ids, self.hier_union_flat_k)
            before = len(stage1_candidate_ids)
            for sid in flat_ids:
                if sid not in stage1_candidate_ids:
                    stage1_candidate_ids.append(sid)
            print(f"\t[hier_union] semantic_top_k={before} + flat_top_{self.hier_union_flat_k}={len(flat_ids)} -> union={len(stage1_candidate_ids)}")
            attempt_record.append({
                "stage": "hier_union",
                "plan": {},
                "n_semantic": before,
                "n_flat": len(flat_ids),
                "n_union": len(stage1_candidate_ids),
            })

    # Stage 2: load episodic sessions only for top-K candidates
    if self.episodic_store is not None and stage1_candidate_ids:
        print(f"\t[Stage 2] Loading episodic memory for "
              f"{len(stage1_candidate_ids)} candidate sessions.")
        raw_sessions = self.episodic_store.get_raw_sessions(
            stage1_candidate_ids, date_lookup or {}, topic_lookup
        )
        self.chat_history = ChatHistory(sessions=raw_sessions)
        print(f"\t[Stage 2] Loaded {len(self.chat_history)} sessions into episodic pool.")

        # hier_v2: skip the agent loop. Use raw turns of the candidate sessions directly.
        # Rationale: the agent loop's verification can reject all of a 20-session pool,
        # leaving empty retrieved. Strong models answer better from raw turns of K=20
        # semantic-selected sessions than from over-aggressive verification.
        if self.hier_v2:
            print(f"\t[hier_v2] bypassing agent loop; returning {len(self.chat_history)} candidate sessions as retrieved")
            return self.chat_history, attempt_record

    pool = self.chat_history
    retrieved = ChatHistory()
    while loop_num < max_loops:
        loop_num += 1
        # 1) Plan (first iteration may come from the module-level plan cache)
        if loop_num == 1 and qid in qid2plan:
            plan = qid2plan[qid]
        else:
            plan = self._plan(question, question_date, attempt_record, model_info)
            qid2plan[qid] = plan
        print(json.dumps(plan, indent=4), flush=True)

        if _has_answer(plan):
            print(f"{qid}\t{question}\t{plan['answer']}", flush=True)
            return ChatHistory(), [{"plan": plan}]

        # 2) Execute the plan -> retrieve candidates
        try:
            candidates, step_logs = self._execute_strategy(pool, plan, question)
        except Exception as e:
            print(f"[Error] Failed during _execute_strategy: {e}")
            candidates, step_logs = ChatHistory(), [f"Execution failed: {e}"]

        if len(candidates) == 0:
            attempt_record.append({
                "loop_iteration": loop_num,
                "plan": plan,
                "evidence": accumulated_evidence,
                "n_candidates_sess": len(candidates),
                "n_verified_sess": 0,
                "n_pool": len(pool),
                "step_logs": step_logs,
            })
            continue
        else:
            # Candidates are consumed: later iterations search only what is left.
            remaining = set(pool.get_session_ids()) - set(candidates.get_session_ids())
            pool = self.chat_history.get_item_by_session_ids(remaining)

        # 3) Verification Reading (cached per qid in the module-level reading cache)
        if qid in qid2rel_sess_ids:
            verified, evidence_list = self._read_and_verify_with_cache(qid, candidates)
        else:
            verified, evidence_list = self._read_and_verify(question, question_date, candidates, n_chunks=self.n_chunks)
            qid2rel_sess_ids[qid] = verified.get_session_ids()

        if len(verified) == 0:
            attempt_record.append({
                "loop_iteration": loop_num,
                "plan": plan,
                "evidence": accumulated_evidence,
                "n_candidates_sess": len(candidates),
                "candidates_sess_ids": candidates.get_session_ids(),
                "n_verified_sess": len(verified),
                "verified_sess_ids": [],
                "n_pool": len(pool),
                "step_logs": step_logs,
            })
            continue
        else:
            retrieved.merge_rel_sess(verified)
            for ev in evidence_list:
                if ev not in accumulated_evidence['chat_clues']:
                    accumulated_evidence['chat_clues'].append(ev)
            attempt_record.append({
                "loop_iteration": loop_num,
                "plan": plan,
                "evidence": accumulated_evidence,
                "n_candidates_sess": len(candidates),
                "candidates_sess_ids": candidates.get_session_ids(),
                "n_verified_sess": len(verified),
                "verified_sess_ids": verified.get_session_ids(),
                "n_retrieved_sess": len(retrieved),
                "retrieved_sess_ids": retrieved.get_session_ids(),
                "n_pool": len(pool),
                "step_logs": step_logs,
            })

        # 4) Decide if continue or not
        if len(retrieved) > top_k:
            retrieved, _ = filter_out_by_embedding(retrieved, qid=qid, top_k=top_k)
        answerable_response = self.is_answerable(question, question_date, retrieved, accumulated_evidence, model_info)
        if answerable_response["is_answerable"]:
            # NOTE(review): ``plan`` may be the object stored in qid2plan, so
            # this answer is also written into the plan cache that is saved to
            # disk; a later cached run would then early-return it above.
            plan["answer"] = answerable_response["answer"]
            return retrieved, attempt_record
        if len(pool) == 0:
            break

    return retrieved, attempt_record
def _execute_strategy(self, pool, plan, question, qid=None):
    """Execute the planner's retrieval ``strategy`` steps over *pool*.

    Each strategy step is a dict with a ``method`` key:
      * ``keyword``     - ``keyword_search(pool, step["keywords"])``; matches are merged in.
      * ``embedding``   - ``embedding_search(pool, qid, top_k=50)``; matches are merged in.
      * ``time_filter`` - keep sessions inside ``step["time_range"]`` (a 2-item
        ``[start, end]``). It filters what earlier steps retrieved, or the whole
        pool when nothing has been retrieved yet. Skipped with ``self.no_time_filter``.
    If ``self.topic_filter`` is set and the plan lists ``topics``, the pool is
    first narrowed with ``pool.get_item_by_topics``. A result with more than
    100 sessions is cut to its embedding top-100.

    Args:
        pool: ChatHistory to search.
        plan: planner output; ``topics`` and ``strategy`` are read.
        question: not used here; kept for interface compatibility.
        qid: question id for embedding look-ups. When omitted, falls back to
            the module-level ``qid`` bound by the ``__main__`` loop, which is
            what this method implicitly relied on before the parameter existed.

    Returns:
        (retrieved, step_logs): the retrieved ChatHistory and one log line per
        step. The log lines are shown to the planner on later loops, so their
        counts must describe what each step actually did.
    """
    if qid is None:
        qid = globals().get("qid")

    step_logs: List[str] = []

    # Start from all chat items, optionally narrowed to the planned topics.
    topics = plan.get("topics") or []
    if self.topic_filter and len(topics) > 0:
        pool = pool.get_item_by_topics(topics)

    retrieved = ChatHistory()
    for step in plan.get("strategy") or []:
        method = step.get("method")
        if method == "keyword":
            kws = step.get("keywords", [])
            matched = keyword_search(pool, kws)
            step_logs.append(f"Method: keyword - matched {len(matched)}/{len(pool)} using {kws}")
            if len(matched) > 0:
                retrieved.merge_rel_sess(matched)
        elif method == "embedding":
            top_k = 50
            matched = embedding_search(pool, qid, top_k=top_k)
            step_logs.append(f"Method: embedding - top_k={top_k}, matched {len(matched)}/{len(pool)}")
            if len(matched) > 0:
                retrieved.merge_rel_sess(matched)
        elif method == "time_filter":
            if self.no_time_filter:
                step_logs.append("Method: time_filter - skipped (--no_time_filter)")
                continue
            time_range = step.get("time_range")
            if not isinstance(time_range, (list, tuple)) or len(time_range) != 2:
                continue
            start_date, end_date = time_range
            # Filter the sessions retrieved so far; if there are none yet,
            # filter the whole pool instead.
            source = retrieved if len(retrieved) > 0 else pool
            n_source = len(source)
            retrieved = time_filter(source, start_date=start_date, end_date=end_date)
            # Report against the set that was actually filtered (was always len(pool)).
            step_logs.append(f"Method: time_filter - kept {len(retrieved)}/{n_source} in {start_date}..{end_date}")
        else:
            step_logs.append(f"unknown method: {method}")

    max_sessions = 100
    if len(retrieved) > max_sessions:
        n_before = len(retrieved)
        retrieved = embedding_search(retrieved, qid, top_k=max_sessions)
        # Report the pre-truncation retrieved size (was len(pool)).
        step_logs.append(
            f"too many sess ({n_before}) - embedding top_k={max_sessions} matched {len(retrieved)}/{n_before}"
        )
    return retrieved, step_logs
def merge_rel_sess(self, new_sessions: ChatHistory):
    """Union *new_sessions* into ``self.rel_sess``, de-duplicated by ``session_id``.

    Sessions already in ``self.rel_sess`` keep their position and take
    precedence over an incoming session with the same id. Unseen sessions
    from *new_sessions* are appended in their original order. ``self.rel_sess``
    is then rebuilt as a new ChatHistory from the merged dates, ids and
    session bodies.
    """
    by_id = {sess["session_id"]: sess for sess in self.rel_sess.sessions}
    for sess in new_sessions.sessions:
        by_id.setdefault(sess["session_id"], sess)

    ordered = list(by_id.values())
    self.rel_sess = ChatHistory({
        "haystack_dates": [sess["session_date"] for sess in ordered],
        "haystack_session_ids": [sess["session_id"] for sess in ordered],
        "haystack_sessions": [sess["session"] for sess in ordered],
    })
def merge_evidence(self, new_evidence: list):
    """Rebind ``self.evidence`` to ``self.evidence + new_evidence`` and print the result.

    The concatenation is assigned rather than extended in place, so a
    reference to the previous ``self.evidence`` object is left unchanged.
    """
    combined = self.evidence + new_evidence
    self.evidence = combined
    print(f"\t\t Updated evidence: {combined}")
if __name__ == "__main__":

    def _parse_bool_flag(value) -> bool:
        """Parse a command-line boolean such as "true", "False", "1" or "no".

        ``argparse`` with ``type=bool`` maps every non-empty string, including
        "False", to True, so it cannot switch a default-on option off.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def _has_direct_answer(plan: dict) -> bool:
        """Return True when *plan* carries an answer that is not a 'none' placeholder.

        ``str()`` guards against non-string JSON answers (numbers, null), which
        would otherwise raise on ``.lower()``.
        """
        answer = plan.get("answer")
        return answer is not None and "none" not in str(answer).lower()

    ANSWER_PROMPT_V2 = (
        "You are answering a question using a list of chat-session transcripts between the user and an assistant.\n"
        "\n"
        "How to answer:\n"
        "1. Scan ALL retrieved sessions in chronological order. The SESSION DATE on each transcript is when that conversation occurred. The Current Date below is when the question was asked, not when events happened.\n"
        "2. Identify every session containing a candidate fact. If sessions conflict, prefer the most RECENT session that addresses the same fact (knowledge update).\n"
        "3. For aggregation questions ('how many', 'list all', 'between X and Y'), enumerate matches across ALL relevant sessions; do not stop at the first.\n"
        "4. For temporal queries ('last Friday', 'two weeks ago'), resolve the relative date against the SESSION DATE of the session that uses that phrase, not the Current Date.\n"
        "5. If the retrieved sessions do NOT contain the answer, reply exactly 'Insufficient information in retrieved sessions.' Do not fabricate.\n"
        "6. Be terse: state the direct answer first, then one short sentence citing the session date(s) you relied on.\n"
        "\n"
        "Chat history sessions:\n"
        "\n"
        "{}\n"
        "\n"
        "Current Date: {}\n"
        "Question: {}\n"
        "Answer:"
    )
    ANSWER_PROMPT_V1 = "I will give you several chat history sessions between you and a user. Please answer the question given the information.\n\n\nChat history sessions:\n\n{}\n\nCurrent Date: {}\nQuestion: {}\nAnswer:"

    parser = argparse.ArgumentParser()
    parser.add_argument('--in_file', type=str, required=True)
    parser.add_argument('--out_file', type=str, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--top_k', type=int, required=True)
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--vllm', action='store_true', default=False)
    parser.add_argument('--tritonai', action='store_true', default=False,
                        help='Use OpenAI-compatible LiteLLM proxy for non-reading LLM calls (set TRITONAI_API_KEY)')
    parser.add_argument('--nvidia', action='store_true', default=False,
                        help='Use NVIDIA inference API (set NV_API_KEY)')
    parser.add_argument('--vllm_reading', action='store_true', default=False,
                        help='Use vLLM only for verification reading; all other LLM calls use the proprietary API')
    parser.add_argument('--n_chunks', type=int, default=10,
                        help='Number of sessions per verification-reading LLM call (default 10)')
    parser.add_argument('--max_loops', type=int, default=3,
                        help='Maximum retrieval/planning loops for agent mode (default 3)')
    parser.add_argument('--mode', type=str, default="agent", choices=['agent', 'embed', 'keyword'])
    # Previously ``type=bool``: "--topic_filter False" silently stayed True.
    parser.add_argument('--topic_filter', type=_parse_bool_flag, default=True,
                        help='Narrow the retrieval pool to planner-selected topics (true/false, default: true)')
    parser.add_argument('--user_profile', action=argparse.BooleanOptionalAction, default=True,
                        help='Include user profile in prompts (default: True, use --no-user_profile to disable)')
    parser.add_argument('--no_semantic', action='store_true', default=False,
                        help='Skip semantic Stage 1; run episodic Stage 2 only on all haystack sessions')
    parser.add_argument('--no_time_filter', action='store_true', default=False,
                        help='Disable time_filter steps in strategy execution (can reuse plan cache)')
    # Two-stage memory arguments
    parser.add_argument('--semantic_ret_cache', type=str, default=None,
                        help='Path to semantic-gte retrieval log (JSONL) for Stage 1')
    parser.add_argument('--summary_file', type=str, default=None,
                        help='Path to all_session_summary.json for SemanticMemoryStore')
    parser.add_argument('--facts_file', type=str, default=None,
                        help='Path to all_session_user_facts.json for SemanticMemoryStore')
    parser.add_argument('--all_sessions_file', type=str, default=None,
                        help='Path to all_sessions.json for lazy episodic loading')
    parser.add_argument('--no_save_cache', action='store_true', default=False,
                        help='Disable saving plan/reading caches to disk after the run')
    parser.add_argument('--hier_v2', action='store_true', default=False,
                        help='Stage 1 produces candidates only: skip early-answer return, semantic keyword expansion, time_filter, and is_answerable shortcut')
    parser.add_argument('--hier_union', action='store_true', default=False,
                        help='hier mode: union Stage-1 semantic candidates with flat-embedding top-K and run agent loop on the merged pool')
    parser.add_argument('--hier_union_flat_k', type=int, default=20,
                        help='How many flat-embedding top-K IDs to union into the Stage-2 pool (default 20)')
    parser.add_argument('--no_early_answer', action='store_true', default=False,
                        help='Disable Stage-1 is_answerable early-return shortcut; always proceed to Stage-2 agent loop')
    parser.add_argument('--answer_prompt_v2', action='store_true', default=False,
                        help='Use the v2 answer prompt with explicit guidance for aggregation, temporal reasoning, knowledge updates, and absence cases.')
    args = parser.parse_args()

    # Rebind reading cache to include n_chunks so different chunk sizes get separate caches.
    # These are module-level names on purpose: RetrievalAgent.run reads and
    # updates the module-global qid2rel_sess_ids.
    veri_reading_log_file = os.environ['reading_cache'] + f'_nchunks{args.n_chunks}'
    qid2rel_sess_ids = {}
    if os.path.exists(veri_reading_log_file):
        with open(veri_reading_log_file) as f:
            qid2rel_sess_ids = json.load(f)
    print(f'Reading cache: {veri_reading_log_file} ({len(qid2rel_sess_ids)} cached entries)')

    with open(args.in_file) as f:
        in_data = json.load(f)
    top_k = args.top_k
    out_file = args.out_file
    model_info = model_zoo[args.model_name]
    deployment_name, api_version = model_info

    # Resume support: skip questions already in out_file and keep their metrics
    # so the running averages cover the whole run.
    existings = set()
    retrieval_metric_list = []
    if os.path.exists(out_file):
        with open(out_file) as f:
            for line in f:
                if not line.strip():
                    continue
                obj = json.loads(line)
                existings.add(obj['question_id'])
                if 'retrieval_metric' in obj:
                    retrieval_metric_list.append(obj['retrieval_metric'])

    ############# read meta files #####################
    with open("metadata/generated_user_profile.json") as f:
        qid2profiles = json.load(f)
    with open("metadata/sessions_with_topic.json") as f:
        sess2topic = json.load(f)

    # ----------------------------------------------------------------
    # Two-stage memory stores (optional; activated by CLI args)
    # ----------------------------------------------------------------
    semantic_store = None
    episodic_store = None
    semantic_ret_dict = None
    if args.summary_file and args.facts_file:
        semantic_store = SemanticMemoryStore(args.summary_file, args.facts_file)
    if args.all_sessions_file:
        episodic_store = EpisodicMemoryStore(args.all_sessions_file)
    if args.semantic_ret_cache:
        print(f"Loading semantic retrieval cache from {args.semantic_ret_cache} ...")
        with open(args.semantic_ret_cache) as f:
            sem_ret_data = [json.loads(line) for line in f if line.strip()]
        semantic_ret_dict = {x['question_id']: x for x in sem_ret_data}
        print(f" Loaded {len(semantic_ret_dict)} entries.")

    # NOTE: retrieval_metric_list is deliberately not reset here; the original
    # reset threw away the metrics just loaded from the resumed out_file.

    answer_prompt_template = ANSWER_PROMPT_V2 if args.answer_prompt_v2 else ANSWER_PROMPT_V1

    with open(out_file, 'a') as out_f:
        for di, entry in enumerate(in_data):
            item_start_time = time.time()
            # ``qid`` must stay a module-level name: RetrievalAgent._execute_strategy
            # falls back to (or, in older versions, reads) this global.
            qid, question, q_date = entry['question_id'], entry['question'], entry['question_date']
            if qid in existings:
                continue

            haystack_sess_ids = entry['haystack_session_ids']
            haystack_topics = [sess2topic.get(sid, {}).get('category', []) for sid in haystack_sess_ids]
            date_lookup = dict(zip(haystack_sess_ids, entry['haystack_dates']))
            topic_lookup = dict(zip(haystack_sess_ids, haystack_topics))

            # Build ChatHistory: lazily from episodic store (two-stage) or from raw data (legacy)
            if episodic_store is not None:
                # Two-stage mode: start with full haystack loaded from episodic store
                # (Stage 1 will narrow this down before Stage 2 runs)
                raw_sessions = episodic_store.get_raw_sessions(
                    haystack_sess_ids, date_lookup, topic_lookup
                )
                chat_history = ChatHistory(sessions=raw_sessions)
            else:
                chat_history = ChatHistory({
                    "haystack_dates": entry['haystack_dates'],
                    "haystack_session_ids": entry['haystack_session_ids'],
                    "haystack_sessions": entry['haystack_sessions'],
                    "haystack_topics": haystack_topics,
                })

            topic_set = set()
            for ht in haystack_topics:
                topic_set.update(ht)

            agent_kwargs = dict(
                debug=args.debug,
                vllm=args.vllm,
                vllm_reading=args.vllm_reading,
                tritonai=args.tritonai,
                nvidia=args.nvidia,
                n_chunks=args.n_chunks,
                topic_filter=args.topic_filter,
                no_time_filter=args.no_time_filter,
                semantic_store=None if args.no_semantic else semantic_store,
                episodic_store=episodic_store,
                hier_v2=args.hier_v2,
                hier_union=args.hier_union,
                hier_union_flat_k=args.hier_union_flat_k,
                no_early_answer=args.no_early_answer,
            )
            if args.user_profile:
                # Profiles are keyed by the base question id (the part before "_q_").
                agent_kwargs["user_profile"] = qid2profiles[qid.split("_q_")[0]]
            agent = RetrievalAgent(chat_history, list(topic_set), **agent_kwargs)

            try:
                if args.mode == 'embed':
                    final_sess = embedding_search(chat_history, qid, top_k=top_k)
                    attempt_record = [{"plan": {"answer": "none", "reason": "embedding retrieval only"}}]
                elif args.mode == 'keyword':
                    keywords = generate_keywords(question, deployment_name, api_version,
                                                 debug=args.debug, vllm=args.vllm,
                                                 tritonai=args.tritonai, nvidia=args.nvidia)
                    final_sess = keyword_search(chat_history, keywords=keywords)
                    attempt_record = [{"plan": {"answer": "none", "reason": "keyword retrieval only"}}]
                else:  # agent
                    final_sess, attempt_record = agent.run(
                        qid, question, q_date, top_k, model_info,
                        max_loops=args.max_loops,
                        semantic_ret_dict=semantic_ret_dict,
                        haystack_sess_ids=haystack_sess_ids,
                        date_lookup=date_lookup,
                        topic_lookup=topic_lookup,
                    )

                if len(attempt_record) == 1 and _has_direct_answer(attempt_record[0]["plan"]):
                    # The agent answered without needing a final answer LLM call.
                    answer = str(attempt_record[0]["plan"]["answer"])
                    token_budget = agent.get_token_budget() if args.mode == 'agent' else {}
                    wall_time_sec = time.time() - item_start_time
                    print(json.dumps({"q_idx": di, 'question_id': qid, 'question': entry['question'],
                                      'answer': answer, 'n_retrieved': len(final_sess),
                                      'wall_time_sec': round(wall_time_sec, 3)}, indent=4), flush=True)
                    print(json.dumps({"q_idx": di, 'question_id': qid,
                                      'hypothesis': answer,
                                      "attempt_record": attempt_record,
                                      "token_budget": token_budget,
                                      "wall_time_sec": wall_time_sec}), file=out_f, flush=True)
                else:
                    if len(final_sess) > top_k and retrieved_log_file is not None:
                        final_top_k_sess, _ = filter_out_by_embedding(final_sess, qid=qid, top_k=top_k)
                        retrieved_str = final_top_k_sess.to_prompt(granularity="session", _format="json")
                    else:
                        retrieved_str = final_sess.to_prompt(granularity="session", _format="json")

                    answer_prompt = answer_prompt_template.format(retrieved_str, entry['question_date'], entry['question'])
                    completion = llm_call(
                        deployment_name,
                        api_version,
                        answer_prompt,
                        debug=args.debug,
                        vllm=args.vllm,
                        tritonai=args.tritonai,
                        nvidia=args.nvidia,
                    )
                    answer = (completion.choices[0].message.content or "").strip()
                    if args.mode == 'agent':
                        agent._track_usage('final_answer', completion)
                    token_budget = agent.get_token_budget() if args.mode == 'agent' else {}

                    retrieval_metric = {}
                    if len(final_sess) > 0 and retrieved_log_file is not None:
                        sess_sorted = embedding_search(final_sess, qid, top_k=20)
                        sess_id_sorted = sess_sorted.get_session_ids()
                        for topk in [5, 10, 20, 30]:
                            recall_any, recall_all = evaluate_retrieval(sess_id_sorted[:topk], entry['answer_session_ids'])
                            retrieval_metric.update({
                                'recall_any@{}'.format(topk): recall_any,
                                'recall_all@{}'.format(topk): recall_all
                            })
                        retrieval_metric_list.append(retrieval_metric)
                        print_average_metrics(retrieval_metric_list)

                    print(json.dumps({"q_idx": di, 'n_prompt_tok': completion.usage.prompt_tokens,
                                      'n_completion_tok': completion.usage.completion_tokens,
                                      'hypothesis': answer,
                                      'wall_time_sec': round(time.time() - item_start_time, 3)}), flush=True)
                    print(json.dumps({"q_idx": di, 'question_id': qid,
                                      'hypothesis': answer,
                                      'n_prompt_tok': completion.usage.prompt_tokens,
                                      'n_completion_tok': completion.usage.completion_tokens,
                                      "attempt_record": attempt_record,
                                      "retrieved_sess_ids": final_sess.get_session_ids(),
                                      "retrieval_metric": retrieval_metric,
                                      "token_budget": token_budget,
                                      "wall_time_sec": time.time() - item_start_time}), file=out_f, flush=True)
            except Exception as e:
                # Per-question boundary: log and move on so one failure does not abort the run.
                print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True)
                continue

    ############# save cache ##########################
    if not args.no_save_cache:
        with open(plan_cache_file, "w") as fw:
            json.dump(qid2plan, fw, indent=2)
        with open(veri_reading_log_file, "w") as fw:
            json.dump(qid2rel_sess_ids, fw, indent=2)