import argparse import os import re import json import time from json import JSONDecodeError from datetime import datetime, timedelta from typing import List, Dict, Any from openai import OpenAI try: from openai import AzureOpenAI from azure.identity import ChainedTokenCredential, AzureCliCredential, ManagedIdentityCredential, get_bearer_token_provider # Azure scope for the OAuth bearer-token provider; override per deployment. AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "") if AZURE_OAUTH_SCOPE: credential = get_bearer_token_provider(ChainedTokenCredential( AzureCliCredential(), ManagedIdentityCredential(), ), AZURE_OAUTH_SCOPE) else: credential = None except ImportError: AzureOpenAI = None credential = None from model_zoo import model_zoo from memory import EpisodicMemoryStore, SemanticMemoryStore try: import tiktoken except ImportError: tiktoken = None try: from transformers import AutoTokenizer, PreTrainedTokenizerBase except ImportError: AutoTokenizer = None PreTrainedTokenizerBase = () from collections import defaultdict def get_hf_tokenizer_for_vllm(model_name: str): return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True) # Azure OpenAI endpoint (set AZURE_OPENAI_ENDPOINT env var to your deployment URL). endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") # OpenAI-compatible LiteLLM proxy URL (set LITELLM_BASE_URL env var to your proxy). TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "") # reading cached files qid2plan = {} if 'plan_cache' in os.environ: plan_cache_file = os.environ['plan_cache'] if os.path.exists(plan_cache_file): qid2plan = json.load(open(plan_cache_file)) else: plan_cache_file = 'response_cache/qa/evolv_mem_v3_plan_cache_gpt5-1' veri_reading_log_file = os.environ['reading_cache'] qid2rel_sess_ids = {} if os.path.exists(veri_reading_log_file): qid2rel_sess_ids = json.load(open(veri_reading_log_file)) # Cache file for retrieval results to avoid re-running expensive retrieval operations. # Stores pre-computed search results for questions, including: # - Question metadata (id, type, text, answer, dates) # - Haystack information (session dates, content, IDs) # - Retrieved results with query, ranked items, and evaluation metrics # Format: JSONL file where each line contains a complete retrieval result for one question retrieved_log_file = None if 'ret_cache' in os.environ: retrieved_log_file = os.environ['ret_cache'] print("loading existing retrieved results ...") retrieved_data = [json.loads(line) for line in open(retrieved_log_file).readlines()] retrieved_data_dict = {x['question_id']: x for x in retrieved_data} valid_sess_set = set(json.load(open("dataset/all_sessions.json")).keys()) def parse_json(response_content): """Safely parse JSON content from a string response.""" candidates = [] if '```json' in response_content: start_idx = response_content.find('```json') + 7 end_idx = response_content.rfind('```') if end_idx > start_idx: candidates.append(response_content[start_idx:end_idx].strip()) # Also try brace-based extraction as fallback brace_start = response_content.find('{') brace_end = response_content.rfind('}') + 1 if brace_start >= 0 and brace_end > brace_start: candidates.append(response_content[brace_start:brace_end].strip()) for json_block in candidates: try: result = json.loads(json_block) return result except (JSONDecodeError, ValueError): continue print(f"[Warning] Failed to decode JSON from response (all strategies failed)") print(f"[Debug] Raw response content (truncated): {response_content[:500]}") return {} def _retryable_status(e): # Try common attributes first — return any extractable HTTP status code status = getattr(e, "status_code", None) or getattr(e, "http_status", None) if status is not None: return int(status) resp = getattr(e, "response", None) if resp is not None and getattr(resp, "status_code", None) is not None: return int(resp.status_code) # Fallback: infer from message text msg = str(e).lower() if "429" in msg or "rate limit" in msg: return 429 if "500" in msg or "internal server error" in msg: return 500 if "503" in msg or "API Configuration unavailable" in msg: return 503 return None MAX_CONTEXT_TOKENS = 272_000 #MAX_CONTEXT_TOKENS = 256_000 class CharacterEncoder: """Conservative fallback when tokenizer packages are unavailable.""" def encode(self, text, **kwargs): return list(text) def decode(self, toks): return "".join(toks) def _get_encoder(model_name: str): """ Return a token encoder for the given model name. Cached to avoid reloading HF tokenizers on every call. """ # Prefer explicit handling for Qwen models first # Note: Adjust the path below if you have the model downloaded locally if AutoTokenizer is not None and any(k in model_name for k in ["Qwen3", "Qwen/", "Qwen"]): try: return AutoTokenizer.from_pretrained( "Qwen/Qwen3-30B-A3B-Instruct-2507", trust_remote_code=True, use_fast=False ) except Exception as e: print(f"[WARN] Failed to load Qwen tokenizer: {e}. Falling back to tiktoken.") # For non-Qwen models, rely on tiktoken's mapping when possible if tiktoken is not None: try: return tiktoken.encoding_for_model(model_name) except Exception: # Generic safe fallback return tiktoken.get_encoding("cl100k_base") return CharacterEncoder() def _truncate_to_tokens(text, enc, max_tokens) -> str: """ Truncates text to the last `max_tokens`. Compatible with both tiktoken and Hugging Face AutoTokenizers. """ # Handle Hugging Face Tokenizers if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase): # add_special_tokens=False is crucial here to avoid double counting # or inserting BOS/EOS in the middle of text during length checks toks = enc.encode(text, add_special_tokens=False) # Handle tiktoken else: toks = enc.encode(text, disallowed_special=()) if len(toks) <= max_tokens: return text # Keep the tail (usually the most relevant for instructions / recent context) toks = toks[-max_tokens:] return enc.decode(toks) def truncate_chat_prompt(tokenizer, messages, max_context, max_output, overhead=256): # Apply the model's chat template so token counting matches the server prompt_text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) input_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids budget = max_context - max_output - overhead if budget < 0: raise ValueError("max_output_tokens + overhead exceeds max_context_tokens") if len(input_ids) > budget: input_ids = input_ids[-budget:] # or keep the *start* depending on your needs prompt_text = tokenizer.decode(input_ids, skip_special_tokens=False) return prompt_text def llm_call(deployment_name: str, api_version: str, _prompt: str, max_context_tokens: int = MAX_CONTEXT_TOKENS, max_output_tokens: int = 1024, extra_overhead_tokens: int = 32, debug: bool = False, vllm: bool = False, tritonai: bool = False, nvidia: bool = False): if nvidia: client = OpenAI( api_key=os.getenv("NV_API_KEY"), base_url="https://inference-api.nvidia.com/v1", ) elif tritonai: client = OpenAI( api_key=os.getenv("TRITONAI_API_KEY"), base_url=TRITONAI_BASE_URL, ) max_context_tokens = 131_072 # DeepSeek R1 128K context # DeepSeek R1 uses thinking tokens before the answer; raise output budget if max_output_tokens < 4096: max_output_tokens = 4096 elif vllm: # Use local vLLM OpenAI-compatible server vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1") vllm_api_key = os.getenv("VLLM_API_KEY", "EMPTY") client = OpenAI( base_url=vllm_base_url, api_key=vllm_api_key, ) # Override deployment_name with the vLLM-served model name when set # (needed when main model is LiteLLM proxy but reading uses local vLLM) deployment_name = os.getenv("VLLM_MODEL_NAME", deployment_name) # vLLM Qwen3-30B-A3B-Instruct-2507 has 131,072-token context max_context_tokens = 131_072 elif debug: client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) else: client = AzureOpenAI( azure_endpoint=endpoint, azure_ad_token_provider=credential, api_version=api_version, ) enc = _get_encoder(deployment_name) # How many tokens we can spend on the input budget = max_context_tokens - max_output_tokens - extra_overhead_tokens if budget < 0: raise ValueError("max_output_tokens + overhead exceeds max_context_tokens") prompt_truncated = _truncate_to_tokens(_prompt, enc, budget) # Strip control characters and fix broken Unicode that break JSON serialization if nvidia or tritonai: prompt_truncated = prompt_truncated.encode('utf-8', errors='replace').decode('utf-8', errors='replace') prompt_truncated = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt_truncated) # Verify it's valid JSON-serializable json.dumps(prompt_truncated) # OpenAI-compatible proxy requires at least one user message; Azure accepts system role msg_role = "user" if (tritonai or nvidia) else "system" kwargs = { 'model': deployment_name, 'messages':[ {"role": msg_role, "content": prompt_truncated} ] } while True: try: completion = client.chat.completions.create(**kwargs) break except Exception as e: from openai import APITimeoutError as _APITimeoutError if isinstance(e, _APITimeoutError): print(f"[WARN] APITimeoutError from LLM; sleeping 30s then retrying...", flush=True) time.sleep(30) continue st = _retryable_status(e) # 404 from LiteLLM proxy/Bedrock is intermittent (model temporarily unavailable) retryable = (429, 500, 503, 403) + ((404,) if tritonai else ()) if st in retryable: print( f"[WARN] q_idx={di} HTTP {st} from LLM; sleeping 60s then retrying...", flush=True ) time.sleep(60) continue # Non-retryable -> re-raise print('One exception captured', repr(e), flush=True) raise #answer = (completion.choices[0].message.content or "").strip() #return answer return completion def custom_to_iso8601(time_str): """ Convert '2023/04/10 (Mon) 23:07' to '2023-04-10T23:07:00' """ # Remove the weekday (e.g., "(Mon)") clean = time_str.split('(')[0].strip() + ' ' + time_str.split(')')[-1].strip() # Parse the cleaned string dt = datetime.strptime(clean, "%Y/%m/%d %H:%M") # Format as ISO 8601 return dt.isoformat() def evaluate_retrieval(recalled_docs, correct_docs, k=10): #recalled_docs = set(corpus_ids[idx] for idx in rankings[:k]) recall_any = float(any(doc in recalled_docs for doc in correct_docs)) recall_all = float(all(doc in recalled_docs for doc in correct_docs)) return recall_any, recall_all def print_average_metrics(retrieval_metric_list): metric_sums = defaultdict(float) metric_counts = defaultdict(int) for metric in retrieval_metric_list: for k, v in metric.items(): metric_sums[k] += v metric_counts[k] += 1 print("\t\t Average metrics:") for k in sorted(metric_sums): avg = metric_sums[k] / metric_counts[k] print(f"\t\t {k}: {avg:.4f}") # Load prompt template prompt_path = "prompts/agentic_retrieval_prompt.txt" with open(prompt_path, "r", encoding="utf-8") as f: stg_prompt = f.read() class ChatHistory: def __init__(self, data: Dict[str, Any] = None, sessions: List = None): assert not (data is not None and sessions is not None), "ChatHistory: Only one of data or sessions may be provided." if data is not None: # From raw data dict self.raw_data = data self.sessions = [] self._parse_sessions() elif sessions is not None: # From provided sessions list self.sessions = sessions self.messages = [] for sess in self.sessions: session_id = sess['session_id'] timestamp = sess['timestamp'] for turn_idx, msg in enumerate(sess['session']): entry = { "role": msg.get("role"), "content": msg.get("content"), "session_id": session_id, "turn_index": turn_idx, "timestamp": timestamp, "iso_datetime": timestamp.isoformat(), "has_answer": msg.get("has_answer", False) } self.messages.append(entry) else: self.sessions = [] self.messages = [] def get_contents(self, granularity='turn', _format='json') -> list: if granularity == "turn": if _format == "json": return [json.dumps(msg) for msg in self.messages] else: return [msg['content'] for msg in self.messages] else: # granularity == 'session' if _format == "json": return [json.dumps(session) for session in self.sessions] else: return [json.dumps({"role": session["role"], "content": session["content"]}) for session in self.sessions] def to_prompt(self, granularity='session', _format="json"): history_str = "" for session in self.sessions: sess_str = json.dumps([{"role": x["role"], "content": x["content"]} for x in session['session']]) history_str += f"Session Date: {session['session_date']}\nSession Content:\n{sess_str}\n" return history_str def get_session_ids(self): return [s['session_id'] for s in self.sessions] @staticmethod def _parse_date(date_str: str) -> datetime: # Convert '2023/04/10 (Mon) 17:50' to datetime # Remove weekday in parentheses date_part, time_part = date_str.split('(')[0].strip(), date_str.split(')')[-1].strip() dt = datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M") return dt def _parse_sessions(self): """ Flattens sessions into a list of messages, each with ISO 8601 date, session ID, and turn index """ self.messages = [] for date_str, session_id, session, topic in zip( self.raw_data['haystack_dates'], self.raw_data['haystack_session_ids'], self.raw_data['haystack_sessions'], self.raw_data['haystack_topics'] ): timestamp = self._parse_date(date_str) for turn_idx, msg in enumerate(session): entry = { "role": msg.get("role"), "content": msg.get("content"), "session_id": session_id, "turn_index": turn_idx, "timestamp": timestamp, "iso_datetime": timestamp.isoformat(), "session_date": date_str, "has_answer": msg.get("has_answer", False) } self.messages.append(entry) self.sessions.append({ "session_date": date_str, "timestamp": timestamp, "session_id": session_id, "session": session, "topic": topic, }) # Optionally, sort by time (ascending) #self.sessions.sort(key=lambda x: x["timestamp"]) #self.messages.sort(key=lambda x: x["timestamp"]) def __len__(self): return len(self.sessions) def __getitem__(self, idx) -> Dict[str, any]: return self.sessions[idx] # Return session dict def get_item_by_index(self, idx): if isinstance(idx, range) or isinstance(idx, list): max_idx = len(self.sessions) valid_indices = [i for i in idx if 0 <= i < max_idx] selected_sessions = [self.sessions[i] for i in valid_indices] return ChatHistory(sessions=selected_sessions) else: raise ValueError("Input must be a list or range of indices.") def get_item_by_session_ids(self, sess_set): if not isinstance(sess_set, set): sess_set = set(sess_set) new_sessions = [] for sess in self.sessions: if sess['session_id'] in sess_set: new_sessions.append(sess) return ChatHistory(sessions=new_sessions) def get_item_by_ranked_session(self, sess_id_sorted): new_sessions = [] for sess_id in sess_id_sorted: for sess in self.sessions: if sess['session_id'] in sess_id: new_sessions.append(sess) return ChatHistory(sessions=new_sessions) def get_item_by_topics(self, topics): new_sessions = [] new_sess_ids = set() for sess in self.sessions: for tp in sess['topic']: if tp in topics and sess['session_id'] not in new_sess_ids: new_sessions.append(sess) new_sess_ids.add(sess['session_id']) break return ChatHistory(sessions=new_sessions) def merge_rel_sess(self, new_sessions): # Gather all current and new sessions in a dict keyed by session_id all_sessions = {s["session_id"]: s for s in self.sessions} # add if new for s in new_sessions.sessions: if s["session_id"] not in all_sessions: all_sessions[s["session_id"]] = s # Reconstruct raw_data for new ChatHistory merged_raw_data = { "haystack_dates": [s["session_date"] for k, s in all_sessions.items()], "haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()], "haystack_sessions": [s["session"] for k, s in all_sessions.items()], "haystack_topics": [s["topic"] for k, s in all_sessions.items()], } self.sessions = ChatHistory(merged_raw_data) def generate_keywords(question: str, deployment_name, api_version, debug=False, vllm=None, tritonai=False, nvidia=False) -> List[str]: # read prompt from `keyword_search_prompt.txt` file with open('prompts/keyword_search_prompt.txt') as f: prompt_template = f.read() prompt = prompt_template + question # Call the LLM to generate keywords completion = llm_call( deployment_name, api_version, prompt, debug=debug, vllm=vllm, tritonai=tritonai, nvidia=nvidia, ) response_content = (completion.choices[0].message.content or "").strip() result = parse_json(response_content) keywords = result["keywords"] if "keywords" in result else [] return keywords def keyword_search(chat_history: ChatHistory, keywords: list): print(f"\t\t** keyword search **: {keywords}") # Gather all messages that match start_time = time.time() matched_msgs = [ msg for msg in chat_history.messages if any(kw.lower() in (msg.get("content") or "").lower() for kw in keywords) ] end_time = time.time() execution_time = end_time - start_time if matched_msgs: new_sess_ids = set() for msg in matched_msgs: key = msg["session_id"] new_sess_ids.add(key) new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids) else: new_chat_history = ChatHistory() return new_chat_history def is_turn_id(text): pattern = r'_\d+$' return bool(re.search(pattern, text)) def embedding_search(chat_history: ChatHistory, qid: str, top_k: int = 50, exclude_sess=None): print("\t\t** embedding based retrieval **") new_sess_ids = [] curr_all_sess = set(chat_history.get_session_ids()) if exclude_sess: curr_all_sess = curr_all_sess - set(exclude_sess) for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]: if item["corpus_id"] in valid_sess_set: sid = item["corpus_id"] else: tokens = item["corpus_id"].split("_") if "_turn" in item["corpus_id"]: sid = item["corpus_id"].split("_turn")[0] elif "_fact" in item["corpus_id"]: sid = item["corpus_id"].split("_fact")[0] elif "noans" in item["corpus_id"]: sid = item["corpus_id"].replace("noans", "answer") elif is_turn_id(item["corpus_id"]): sid = "_".join(tokens[:-1]) # remove turn index else: sid = item["corpus_id"] if sid not in valid_sess_set: print(item["corpus_id"], sid) assert sid in valid_sess_set if sid in curr_all_sess: new_sess_ids.append(sid) if len(new_sess_ids) == top_k: break new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids) return new_chat_history def filter_out_by_embedding(chat_history: ChatHistory, qid: str, top_k: int = 50): print("\t\t** [filter_out] embedding based retrieval - loading existing results ...") new_sess_ids = [] curr_all_sess = set(chat_history.get_session_ids()) for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]: if item["corpus_id"] in valid_sess_set: sid = item["corpus_id"] else: tokens = item["corpus_id"].split("_") if "_turn" in item["corpus_id"]: sid = item["corpus_id"].split("_turn")[0] elif "_fact" in item["corpus_id"]: sid = item["corpus_id"].split("_fact")[0] elif "noans" in item["corpus_id"]: sid = item["corpus_id"].replace("noans", "answer") elif is_turn_id(item["corpus_id"]): sid = "_".join(tokens[:-1]) # remove turn index else: sid = item["corpus_id"] if sid not in valid_sess_set: print(item["corpus_id"], sid) assert sid in valid_sess_set if sid in curr_all_sess: new_sess_ids.append(sid) if len(new_sess_ids) == top_k: break new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids) return new_chat_history, 0.0 def flat_embedding_top_k_ids(qid: str, haystack_sess_ids: List[str], top_k: int) -> List[str]: """ Pull the top_k session IDs from the global GTE retrieval cache (retrieved_data_dict), constrained to the question's haystack. Mirrors embedding_search() but operates on IDs only (no ChatHistory). Used by hier_union to widen the Stage-2 pool. """ haystack_set = set(haystack_sess_ids) ids: List[str] = [] for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]: cid = item["corpus_id"] if cid in valid_sess_set: sid = cid elif "_turn" in cid: sid = cid.split("_turn")[0] elif "_fact" in cid: sid = cid.split("_fact")[0] elif "noans" in cid: sid = cid.replace("noans", "answer") elif is_turn_id(cid): sid = "_".join(cid.split("_")[:-1]) else: sid = cid if sid in haystack_set and sid not in ids: ids.append(sid) if len(ids) == top_k: break return ids def semantic_embedding_search( qid: str, haystack_sess_ids: List[str], semantic_retrieved_dict: dict, top_k: int = 50, ) -> List[str]: """ Like embedding_search() but reads from the pre-computed semantic-gte retrieval cache. Returns an ordered list of up to top_k session IDs from the haystack. """ print("\t\t** semantic embedding retrieval **") haystack_set = set(haystack_sess_ids) ranked_ids: List[str] = [] for item in semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]: sid = item["corpus_id"] # already session-level (no turn suffix) if sid in haystack_set and sid not in ranked_ids: ranked_ids.append(sid) if len(ranked_ids) == top_k: break return ranked_ids def time_filter(chat_history: ChatHistory, start_date: str, end_date: str) -> ChatHistory: # Returns all messages with timestamp in the ISO date range [start_date, end_date] (inclusive). start_time = time.time() try: start = datetime.fromisoformat(start_date) end = datetime.fromisoformat(end_date) filtered_msgs = [msg for msg in chat_history.messages if start.date() <= msg["timestamp"].date() <= end.date()] except Exception as e: print("Converting date error: ", e) filtered_msgs = [] end_time = time.time() execution_time = end_time - start_time if filtered_msgs: new_sess_ids = set() for msg in filtered_msgs: key = msg["session_id"] new_sess_ids.add(key) new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids) else: new_chat_history = ChatHistory() return new_chat_history class RetrievalAgent: def __init__( self, history: List[Dict], topics: List[str], user_profile: str = None, debug: bool = False, vllm: bool = False, vllm_reading: bool = False, tritonai: bool = False, nvidia: bool = False, n_chunks: int = 10, topic_filter: bool = True, no_time_filter: bool = False, semantic_store: SemanticMemoryStore = None, episodic_store: EpisodicMemoryStore = None, hier_v2: bool = False, hier_union: bool = False, hier_union_flat_k: int = 20, no_early_answer: bool = False, ): self.chat_history = history self.user_profile = user_profile self.topics = topics self.rel_sess = ChatHistory() self.evidence = [] self.debug = debug self.vllm = vllm self.vllm_reading = vllm_reading # use vLLM only for _read_and_verify self.tritonai = tritonai # use LiteLLM proxy for non-reading LLM calls self.nvidia = nvidia # use NVIDIA inference API self.no_time_filter = no_time_filter # skip time_filter steps in strategy self.n_chunks = n_chunks self.topic_filter = topic_filter self.semantic_store = semantic_store self.episodic_store = episodic_store self.hier_v2 = hier_v2 self.hier_union = hier_union self.hier_union_flat_k = hier_union_flat_k self.no_early_answer = no_early_answer self.token_budget = { 'planning': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, 'verification_reading': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, 'is_answerable': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, 'final_answer': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, } if self.user_profile: with open('prompts/read_and_extract_prompt.txt') as f: self.read_prompt_template = f.read() else: # ablation: wo_profile with open('prompts/agentic_retrieval_prompt_wo_profile.txt') as f: self.read_prompt_template = f.read() def _track_usage(self, component: str, completion) -> None: """Accumulate prompt/completion token counts for a named component.""" usage = getattr(completion, 'usage', None) if usage is None: return self.token_budget[component]['prompt_tokens'] += getattr(usage, 'prompt_tokens', 0) or 0 self.token_budget[component]['completion_tokens'] += getattr(usage, 'completion_tokens', 0) or 0 self.token_budget[component]['n_calls'] += 1 def get_token_budget(self) -> dict: """Return token_budget with an added 'total' entry.""" total = {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0} for v in self.token_budget.values(): total['prompt_tokens'] += v['prompt_tokens'] total['completion_tokens'] += v['completion_tokens'] total['n_calls'] += v['n_calls'] return {**self.token_budget, 'total': total} def is_answerable(self, question: str, question_date: str, retrieved_sess, evidence, model_info, context_str: str = None) -> bool: context = "" for k, v in evidence.items(): # key: "profile", "tags", "chat_clues" for e in v: context += f"{e}\n" # context_str overrides retrieved_sess.to_prompt() (used in Stage 1 with semantic context) sess_str = context_str if context_str is not None else retrieved_sess.to_prompt() # Include user profile in the prompt when available profile_section = "" if self.user_profile: profile_section = f"\nUser Profile:\n{self.user_profile}\n" ia_prompt_prefix = f""" You are a decision-making agent tasked with determining when sufficient information has been gathered to answer a user's question. Your Task: Analyze the provided question, current date, available memory context, and available evidence to make a binary decision: Answerable or Not answerable. If the information is not sufficient, explain what specific information is needed to provide hints for the next retrieval stage. Question: {question} Current Date: {question_date} {profile_section} Memory Context: """ output_str = """ Output (always JSON — choose fields per the rules above) Case 1 — Answerable: { "is_answerable": true, "answer": "" } Case 2 — Not answerable: { "is_answerable": false, "info_needed": ["", ""] } """ deployment_name, api_version = model_info # ------------------------------------------------------------------ # Token-based truncation: keep the *end* of sess_str under budget # ------------------------------------------------------------------ enc = _get_encoder(deployment_name) # Model/context limits if self.vllm or self.tritonai: # Large context for vLLM / LiteLLM proxy models model_max_ctx = 131_072 else: model_max_ctx = MAX_CONTEXT_TOKENS max_output_tokens = 1024 extra_overhead_tokens = 32 # Total budget available for input tokens budget = model_max_ctx - max_output_tokens - extra_overhead_tokens if budget <= 0: raise ValueError( f"max_output_tokens ({max_output_tokens}) + overhead " f"({extra_overhead_tokens}) exceeds model_max_ctx ({model_max_ctx})" ) # Token lengths of static pieces prefix_tokens = enc.encode(ia_prompt_prefix, disallowed_special=()) output_tokens = enc.encode(output_str, disallowed_special=()) sess_tokens = enc.encode(sess_str, disallowed_special=()) # Budget for sess_str tokens available_for_sess = budget - len(prefix_tokens) - len(output_tokens) if available_for_sess <= 0: # No room for history at all; drop it truncated_sess_str = "" else: if len(sess_tokens) > available_for_sess: # Keep the *last* available_for_sess tokens (drop oldest history) truncated_sess_tokens = sess_tokens[-available_for_sess:] truncated_sess_str = enc.decode(truncated_sess_tokens) else: truncated_sess_str = sess_str ia_prompt = ia_prompt_prefix + truncated_sess_str + "\n" # ------------------------------------------------------------------ # Call the LLM with the already-truncated prompt # ------------------------------------------------------------------ completion = llm_call( deployment_name, api_version, ia_prompt + output_str, max_context_tokens=model_max_ctx, # matches what we used for budgeting max_output_tokens=max_output_tokens, extra_overhead_tokens=extra_overhead_tokens, debug=self.debug, vllm=self.vllm, tritonai=self.tritonai, nvidia=self.nvidia, ) self._track_usage('is_answerable', completion) response_content = (completion.choices[0].message.content or "").strip() print(f"\t\t[Agent] is_answerable: {response_content}") result = parse_json(response_content) if not result: print("[Warning] Empty or invalid JSON in is_answerable() response.") return {"is_answerable": False, "info_needed": ["Parsing failed"]} return result def _read_and_verify(self, question: str, question_date: str, evidence: ChatHistory, n_chunks=10) -> ChatHistory: # Read evidence and and select relevant_indices = [] evidence_list = [] max_idx = len(evidence) for j in range(0, len(evidence), n_chunks): chunk_range = range(j, j+n_chunks) valid_indices = [i for i in chunk_range if 0 <= i < max_idx] cur_chunk = evidence.get_item_by_index(valid_indices) cur_chunk_sess = [[{"role": m["role"], "content": m["content"]} for m in sess['session']] for sess in cur_chunk.sessions] cur_chunk_sess_date = [sess['session_date'] for sess in cur_chunk.sessions] sess_input_str = "\n".join([ f"### Session Index: {i}\n### Session Date: {sess_date}\n\n{json.dumps(sess)}\n" for i, (sess, sess_date) in enumerate(zip(cur_chunk_sess, cur_chunk_sess_date)) ]) _prompt = self.read_prompt_template + f"## Question: {question}\n## Question Date: {question_date}\n## Session list:\n\n{sess_input_str}\nNow, identify **only the sessions strictly necessary to answer the question**." completion = llm_call( deployment_name, api_version, _prompt, debug=self.debug, vllm=self.vllm or self.vllm_reading, nvidia=self.nvidia, ) self._track_usage('verification_reading', completion) response_content = (completion.choices[0].message.content or "").strip() print(f"\t\t {valid_indices[0]}~{valid_indices[-1]}: response: {response_content.replace(chr(10), '')}") try: start_idx = response_content.rfind('{') end_idx = response_content.rfind('}') + 1 json_block = response_content[start_idx:end_idx] result = json.loads(json_block) if "index" in result and result['index'] and 'evidence' and result: relevant_indices.extend([j + idx for idx in result['index']]) evidence_list.extend(result['evidence']) except Exception as e: print(f"Error parsing LLM response: {e}") if relevant_indices: return evidence.get_item_by_index(relevant_indices), evidence_list else: return ChatHistory(), [] def _read_and_verify_with_cache(self, qid:str, pool): relevant_sess_ids = [] for sess in pool.sessions: sess_id = sess['session_id'] if sess_id in qid2rel_sess_ids[qid]: relevant_sess_ids.append(sess_id) if len(relevant_sess_ids) > 0: return pool.get_item_by_session_ids(relevant_sess_ids), [] else: return ChatHistory(), [] def _plan(self, query: str, query_date: str, attempt_record: list, model_info) -> str: if self.user_profile: template = """ ### User profile: {user_profile} ### Chat history topics: {topics} ### User query: {query} ### User query date: {query_date} ### Previous attempts: {strategies_info} """ else: template = """ ### Chat history topics: {topics} ### User query: {query} ### User query date: {query_date} ### Previous attempts: {strategies_info} """ if attempt_record: strategies_info = "" for loop_num, entry in enumerate(attempt_record): strategies_info += f"\nloop_iteration: {loop_num+1}\n" strategies_info += "\n".join(entry.get('step_logs', [])) if 'n_retrieved_sess' in entry and 'evidence' in entry: strategies_info += f"Retrieved {entry['n_retrieved_sess']} docs, observed_evidence: {entry['evidence']}" if entry['n_retrieved_sess'] == 0: strategies_info += f"Additional Instruction: Re-try without filter methods if the previous paln includes topics or time-filtering\n" else: strategies_info = "(No previous attempt exists)" if self.user_profile: prompt_filled = stg_prompt + template.format( user_profile=self.user_profile, topics=",".join(self.topics), query=query, query_date=query_date, strategies_info=strategies_info) else: prompt_filled = stg_prompt + template.format( topics=",".join(self.topics), query=query, query_date=query_date, strategies_info=strategies_info) deployment_name, api_version = model_info completion = llm_call( deployment_name, api_version, prompt_filled, debug=self.debug, vllm=self.vllm, tritonai=self.tritonai, nvidia=self.nvidia, ) self._track_usage('planning', completion) response_content = (completion.choices[0].message.content or "").strip() _plan = parse_json(response_content) if not _plan: print("[Warning] Failed to parse plan JSON — retrying once.") completion = llm_call( deployment_name, api_version, prompt_filled, debug=self.debug, vllm=self.vllm, tritonai=self.tritonai, nvidia=self.nvidia, ) self._track_usage('planning', completion) response_content = (completion.choices[0].message.content or "").strip() _plan = parse_json(response_content) if not _plan: print("[Warning] Failed to parse plan JSON after retry — returning fallback plan.") _plan = {"answer": "none", "reason": "invalid JSON response", "topics": [], "strategy": []} return _plan def _run_stage1( self, qid: str, question: str, question_date: str, top_k: int, model_info, haystack_sess_ids: List[str], date_lookup: Dict[str, str], semantic_ret_dict: dict, ) -> dict: """ Stage 1: retrieve and evaluate using semantic memory only (summaries + facts). Returns a dict with: is_answerable : bool answer : str | None (set when is_answerable is True) candidate_ids : list[str] (top-K session IDs from semantic retrieval) attempt_record : list """ print(f"\t[Stage 1] Semantic memory retrieval for qid={qid}") # --- 1a. Semantic embedding search --- candidate_ids = semantic_embedding_search( qid, haystack_sess_ids, semantic_ret_dict, top_k=top_k ) # hier_v2: skip plan / keyword / time_filter / is_answerable. Stage 1 is candidate-only. if self.hier_v2: print(f"\t[Stage 1 hier_v2] embedding-only candidates: {len(candidate_ids)}") return { "is_answerable": False, "answer": None, "candidate_ids": candidate_ids[:top_k], "attempt_record": [{ "stage": "semantic_v2", "plan": {}, "n_candidates": len(candidate_ids), "candidate_ids": candidate_ids[:top_k], }], } # --- 1b. Plan for keywords / time filter (reuse existing _plan) --- plan = self._plan(question, question_date, [], model_info) print(json.dumps(plan, indent=4), flush=True) # If the planner already has a direct answer, return it if "answer" in plan and plan["answer"].lower() != "none": return { "is_answerable": True, "answer": plan["answer"], "candidate_ids": candidate_ids, "attempt_record": [{"plan": plan, "stage": "semantic"}], } # --- 1c. Semantic keyword search --- keyword_ids: List[str] = [] for step in plan.get("strategy", []): if step.get("method") == "keyword": kws = step.get("keywords", []) matched = self.semantic_store.keyword_search(kws, haystack_sess_ids) print(f"\t\t** semantic keyword search **: {kws} -> {len(matched)} matches") keyword_ids.extend(sid for sid in matched if sid not in keyword_ids) # --- 1d. Time filter on candidate set --- for step in plan.get("strategy", []): if self.no_time_filter: break # skip all time_filter steps if step.get("method") == "time_filter": if "time_range" not in step or len(step["time_range"]) != 2: continue start_str, end_str = step["time_range"] from datetime import datetime try: start_dt = datetime.fromisoformat(start_str) end_dt = datetime.fromisoformat(end_str) candidate_ids = [ sid for sid in candidate_ids if sid in date_lookup and start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date() ] keyword_ids = [ sid for sid in keyword_ids if sid in date_lookup and start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date() ] print(f"\t\t** semantic time_filter **: {start_str}..{end_str} -> " f"{len(candidate_ids)} embed, {len(keyword_ids)} keyword") except Exception as e: print(f"\t\t[WARN] time_filter parse error: {e}") # --- 1e. Merge candidates (keyword union with embedding, preserve rank) --- all_candidate_ids: List[str] = list(candidate_ids) for sid in keyword_ids: if sid not in all_candidate_ids: all_candidate_ids.append(sid) # Cap to top_k all_candidate_ids = all_candidate_ids[:top_k] if not all_candidate_ids: print("\t[Stage 1] No candidates found in semantic memory.") return { "is_answerable": False, "answer": None, "candidate_ids": [], "attempt_record": [{"plan": plan, "stage": "semantic", "n_candidates": 0}], } # --- 1f. Build semantic context string --- semantic_context_str = self.semantic_store.to_prompt(all_candidate_ids, date_lookup) print(f"\t[Stage 1] Built semantic context for {len(all_candidate_ids)} sessions " f"({len(semantic_context_str)} chars)") # --- 1g. is_answerable check on semantic context --- accumulated_evidence = {"profile": [], "chat_clues": []} answerable_response = self.is_answerable( question, question_date, retrieved_sess=None, evidence=accumulated_evidence, model_info=model_info, context_str=semantic_context_str, ) print(f"\t[Stage 1] is_answerable: {answerable_response}") attempt_record = [{ "stage": "semantic", "plan": plan, "n_candidates": len(all_candidate_ids), "candidate_ids": all_candidate_ids, "is_answerable": answerable_response.get("is_answerable", False), }] if answerable_response.get("is_answerable"): return { "is_answerable": True, "answer": answerable_response.get("answer"), "candidate_ids": all_candidate_ids, "attempt_record": attempt_record, } print(f"\t[Stage 1] Not answerable from semantic memory. " f"Info needed: {answerable_response.get('info_needed', [])}") return { "is_answerable": False, "answer": None, "candidate_ids": all_candidate_ids, "attempt_record": attempt_record, } def run(self, qid:str, question: str, question_date: str, top_k: int, model_info, max_loops=3, semantic_ret_dict: dict = None, haystack_sess_ids: List[str] = None, date_lookup: Dict[str, str] = None, topic_lookup: Dict[str, List[str]] = None): accumulated_evidence = {"profile": [], "chat_clues": []} attempt_record = [] loop_num = 0 # ---------------------------------------------------------------- # Stage 1: Semantic memory — only runs when stores are provided # ---------------------------------------------------------------- stage1_candidate_ids: List[str] = [] if (self.semantic_store is not None and semantic_ret_dict is not None and haystack_sess_ids is not None): stage1_result = self._run_stage1( qid, question, question_date, top_k, model_info, haystack_sess_ids, date_lookup or {}, semantic_ret_dict, ) attempt_record.extend(stage1_result["attempt_record"]) stage1_candidate_ids = stage1_result["candidate_ids"] if stage1_result["is_answerable"] and not self.no_early_answer: print(f"\t[Stage 1] Answered from semantic memory.") # Wrap answer in attempt_record format expected by caller if "answer" in stage1_result and stage1_result["answer"]: attempt_record[0]["plan"] = { **attempt_record[0].get("plan", {}), "answer": stage1_result["answer"], } return ChatHistory(), attempt_record if stage1_result["is_answerable"] and self.no_early_answer: print(f"\t[Stage 1] is_answerable=True but --no_early_answer set; proceeding to Stage-2.") # hier_union: widen Stage-2 pool with flat-embedding top-K from the global GTE cache. # This makes hier a strict superset of flat by construction; targets the recall gap # (semantic-over-summary embeddings rank worse than full-session embeddings). if self.hier_union and qid in retrieved_data_dict: flat_ids = flat_embedding_top_k_ids(qid, haystack_sess_ids, self.hier_union_flat_k) before = len(stage1_candidate_ids) for sid in flat_ids: if sid not in stage1_candidate_ids: stage1_candidate_ids.append(sid) print(f"\t[hier_union] semantic_top_k={before} + flat_top_{self.hier_union_flat_k}={len(flat_ids)} -> union={len(stage1_candidate_ids)}") attempt_record.append({ "stage": "hier_union", "plan": {}, "n_semantic": before, "n_flat": len(flat_ids), "n_union": len(stage1_candidate_ids), }) # Stage 2: load episodic sessions only for top-K candidates if self.episodic_store is not None and stage1_candidate_ids: print(f"\t[Stage 2] Loading episodic memory for " f"{len(stage1_candidate_ids)} candidate sessions.") raw_sessions = self.episodic_store.get_raw_sessions( stage1_candidate_ids, date_lookup or {}, topic_lookup ) self.chat_history = ChatHistory(sessions=raw_sessions) print(f"\t[Stage 2] Loaded {len(self.chat_history)} sessions into episodic pool.") # hier_v2: skip the agent loop. Use raw turns of the candidate sessions directly. # Rationale: the agent loop's verification can reject all of a 20-session pool, # leaving empty retrieved. Strong models answer better from raw turns of K=20 # semantic-selected sessions than from over-aggressive verification. if self.hier_v2: print(f"\t[hier_v2] bypassing agent loop; returning {len(self.chat_history)} candidate sessions as retrieved") return self.chat_history, attempt_record pool = self.chat_history retrieved = ChatHistory() while loop_num < max_loops: loop_num += 1 if loop_num == 1 and qid in qid2plan: plan = qid2plan[qid] else: plan = self._plan(question, question_date, attempt_record, model_info) qid2plan[qid] = plan print(json.dumps(plan, indent=4), flush=True) if "answer" in plan and not ("none" in plan["answer"].lower()): print(f"{qid}\t{question}\t{plan['answer']}", flush=True) return ChatHistory(), [{"plan": plan}] # 2) Execute the plan -> retrieve candidates try: candidates, step_logs = self._execute_strategy(pool, plan, question) except Exception as e: print(f"[Error] Failed during _execute_strategy: {e}") candidates, step_logs = ChatHistory(), [f"Execution failed: {e}"] if len(candidates) == 0: attempt_record.append({ "loop_iteration": loop_num, "plan": plan, "evidence": accumulated_evidence, "n_candidates_sess": len(candidates), "n_verified_sess": 0, "n_pool": len(pool), "step_logs": step_logs, }) continue else: remaining = set(pool.get_session_ids()) - set(candidates.get_session_ids()) pool = self.chat_history.get_item_by_session_ids(remaining) # 3) Verification Reading if qid in qid2rel_sess_ids: verified, evidence_list = self._read_and_verify_with_cache(qid, candidates) else: verified, evidence_list = self._read_and_verify(question, question_date, candidates, n_chunks=self.n_chunks) qid2rel_sess_ids[qid] = verified.get_session_ids() if len(verified) == 0: attempt_record.append({ "loop_iteration": loop_num, "plan": plan, "evidence": accumulated_evidence, "n_candidates_sess": len(candidates), "candidates_sess_ids": candidates.get_session_ids(), "n_verified_sess": len(verified), "verified_sess_ids": [], "n_pool": len(pool), "step_logs": step_logs, }) continue else: retrieved.merge_rel_sess(verified) for ev in evidence_list: if ev not in accumulated_evidence['chat_clues']: accumulated_evidence['chat_clues'].append(ev) attempt_record.append({ "loop_iteration": loop_num, "plan": plan, "evidence": accumulated_evidence, "n_candidates_sess": len(candidates), "candidates_sess_ids": candidates.get_session_ids(), "n_verified_sess": len(verified), "verified_sess_ids": verified.get_session_ids(), "n_retrieved_sess": len(retrieved), "retrieved_sess_ids": retrieved.get_session_ids(), "n_pool": len(pool), "step_logs": step_logs, }) # 4) Decide if continue or not if len(retrieved) > top_k: retrieved, _ = filter_out_by_embedding(retrieved, qid=qid, top_k=top_k) answerable_response = self.is_answerable(question, question_date, retrieved, accumulated_evidence, model_info) if answerable_response["is_answerable"]: plan["answer"] = answerable_response["answer"] return retrieved, attempt_record if len(pool) == 0: break return retrieved, attempt_record def _execute_strategy(self, pool, plan, question): step_logs: List[str] = [] # Start from all chat items if self.topic_filter and len(plan.get('topics', [])) > 0: pool = pool.get_item_by_topics(plan['topics']) retrieved = ChatHistory() strategy = plan["strategy"] for step in strategy: method = step.get("method") if method == "keyword": kws = step.get("keywords", []) matched = keyword_search(pool, kws) step_logs.append(f"Method: keyword - matched {len(matched)}/{len(pool)} using {kws}") if len(matched) > 0: retrieved.merge_rel_sess(matched) elif method == "embedding": top_k = 50 matched = embedding_search(pool, qid, top_k=top_k) step_logs.append(f"Method: embedding - top_k={top_k}, matched {len(matched)}/{len(pool)}") if len(matched) > 0: retrieved.merge_rel_sess(matched) elif method == "time_filter": if self.no_time_filter: step_logs.append(f"Method: time_filter - skipped (--no_time_filter)") continue if 'time_range' not in step or len(step['time_range']) != 2: continue if len(retrieved) > 0: retrieved = time_filter(retrieved, start_date=step['time_range'][0], end_date=step['time_range'][1]) else: retrieved = time_filter(pool, start_date=step['time_range'][0], end_date=step['time_range'][1]) step_logs.append(f"Method: time_filter - kept {len(retrieved)}/{len(pool)} in {step['time_range'][0]}..{step['time_range'][1]}") #if len(matched) > 0: # retrieved.merge_rel_sess(matched) else: step_logs.append(f"unknown method: {method}") if len(retrieved) > 100: top_k = 100 retrieved = embedding_search(retrieved, qid, top_k=top_k) step_logs.append( f"too many sess ({len(pool)}) - embedding top_k={top_k} matched {len(retrieved)}/{len(pool)}" ) return retrieved, step_logs def merge_rel_sess(self, new_sessions: ChatHistory): # Gather all current and new sessions in a dict keyed by session_id all_sessions = {s["session_id"]: s for s in self.rel_sess.sessions} # add if new for s in new_sessions.sessions: if s["session_id"] not in all_sessions: all_sessions[s["session_id"]] = s # Optional: sort sessions by timestamp for consistent ordering #merged_sessions = list(all_sessions.values()) #merged_sessions.sort(key=lambda x: x["timestamp"]) # Reconstruct raw_data for new ChatHistory merged_raw_data = { "haystack_dates": [s["session_date"] for k, s in all_sessions.items()], "haystack_session_ids": [s["session_id"] for k, s in all_sessions.items()], "haystack_sessions": [s["session"] for k, s in all_sessions.items()], } self.rel_sess = ChatHistory(merged_raw_data) def merge_evidence(self, new_evidence: list): self.evidence = self.evidence + new_evidence print(f"\t\t Updated evidence: {self.evidence}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--in_file', type=str, required=True) parser.add_argument('--out_file', type=str, required=True) parser.add_argument('--model_name', type=str, required=True) parser.add_argument('--top_k', type=int, required=True) parser.add_argument('--debug', action='store_true', default=False) parser.add_argument('--vllm', action='store_true', default=False) parser.add_argument('--tritonai', action='store_true', default=False, help='Use OpenAI-compatible LiteLLM proxy for non-reading LLM calls (set TRITONAI_API_KEY)') parser.add_argument('--nvidia', action='store_true', default=False, help='Use NVIDIA inference API (set NV_API_KEY)') parser.add_argument('--vllm_reading', action='store_true', default=False, help='Use vLLM only for verification reading; all other LLM calls use the proprietary API') parser.add_argument('--n_chunks', type=int, default=10, help='Number of sessions per verification-reading LLM call (default 10)') parser.add_argument('--max_loops', type=int, default=3, help='Maximum retrieval/planning loops for agent mode (default 3)') parser.add_argument('--mode', type=str, default="agent", choices=['agent', 'embed', 'keyword']) parser.add_argument('--topic_filter', type=bool, default=True) parser.add_argument('--user_profile', action=argparse.BooleanOptionalAction, default=True, help='Include user profile in prompts (default: True, use --no-user_profile to disable)') parser.add_argument('--no_semantic', action='store_true', default=False, help='Skip semantic Stage 1; run episodic Stage 2 only on all haystack sessions') parser.add_argument('--no_time_filter', action='store_true', default=False, help='Disable time_filter steps in strategy execution (can reuse plan cache)') # Two-stage memory arguments parser.add_argument('--semantic_ret_cache', type=str, default=None, help='Path to semantic-gte retrieval log (JSONL) for Stage 1') parser.add_argument('--summary_file', type=str, default=None, help='Path to all_session_summary.json for SemanticMemoryStore') parser.add_argument('--facts_file', type=str, default=None, help='Path to all_session_user_facts.json for SemanticMemoryStore') parser.add_argument('--all_sessions_file', type=str, default=None, help='Path to all_sessions.json for lazy episodic loading') parser.add_argument('--no_save_cache', action='store_true', default=False, help='Disable saving plan/reading caches to disk after the run') parser.add_argument('--hier_v2', action='store_true', default=False, help='Stage 1 produces candidates only: skip early-answer return, semantic keyword expansion, time_filter, and is_answerable shortcut') parser.add_argument('--hier_union', action='store_true', default=False, help='hier mode: union Stage-1 semantic candidates with flat-embedding top-K and run agent loop on the merged pool') parser.add_argument('--hier_union_flat_k', type=int, default=20, help='How many flat-embedding top-K IDs to union into the Stage-2 pool (default 20)') parser.add_argument('--no_early_answer', action='store_true', default=False, help='Disable Stage-1 is_answerable early-return shortcut; always proceed to Stage-2 agent loop') parser.add_argument('--answer_prompt_v2', action='store_true', default=False, help='Use the v2 answer prompt with explicit guidance for aggregation, temporal reasoning, knowledge updates, and absence cases.') args = parser.parse_args() # Rebind reading cache to include n_chunks so different chunk sizes get separate caches veri_reading_log_file = os.environ['reading_cache'] + f'_nchunks{args.n_chunks}' qid2rel_sess_ids = {} if os.path.exists(veri_reading_log_file): qid2rel_sess_ids = json.load(open(veri_reading_log_file)) print(f'Reading cache: {veri_reading_log_file} ({len(qid2rel_sess_ids)} cached entries)') in_data = json.load(open(args.in_file)) top_k = args.top_k out_file = args.out_file model_info = model_zoo[args.model_name] deployment_name, api_version = model_info existings = set() retrieval_metric_list = [] if os.path.exists(out_file): for line in open(out_file): obj = json.loads(line) existings.add(obj['question_id']) if 'retrieval_metric' in obj: retrieval_metric_list.append(obj['retrieval_metric']) out_f = open(out_file, 'a') ############# read meta files ##################### qid2profiles = {} with open("metadata/generated_user_profile.json") as f: qid2profiles = json.load(f) sess2topic = {} with open("metadata/sessions_with_topic.json") as f: sess2topic = json.load(f) # ---------------------------------------------------------------- # Two-stage memory stores (optional; activated by CLI args) # ---------------------------------------------------------------- semantic_store = None episodic_store = None semantic_ret_dict = None if args.summary_file and args.facts_file: semantic_store = SemanticMemoryStore(args.summary_file, args.facts_file) if args.all_sessions_file: episodic_store = EpisodicMemoryStore(args.all_sessions_file) if args.semantic_ret_cache: print(f"Loading semantic retrieval cache from {args.semantic_ret_cache} ...") sem_ret_data = [json.loads(line) for line in open(args.semantic_ret_cache)] semantic_ret_dict = {x['question_id']: x for x in sem_ret_data} print(f" Loaded {len(semantic_ret_dict)} entries.") retrieval_metric_list = [] for di, entry in enumerate(in_data): item_start_time = time.time() qid, question, q_date = entry['question_id'], entry['question'], entry['question_date'] q_date = entry['question_date'] if qid in existings: continue haystack_sess_ids = entry['haystack_session_ids'] haystack_topics = [sess2topic.get(sid, {}).get('category', []) for sid in haystack_sess_ids] date_lookup = dict(zip(haystack_sess_ids, entry['haystack_dates'])) topic_lookup = dict(zip(haystack_sess_ids, haystack_topics)) # Build ChatHistory: lazily from episodic store (two-stage) or from raw data (legacy) if episodic_store is not None: # Two-stage mode: start with full haystack loaded from episodic store # (Stage 1 will narrow this down before Stage 2 runs) raw_sessions = episodic_store.get_raw_sessions( haystack_sess_ids, date_lookup, topic_lookup ) chat_history = ChatHistory(sessions=raw_sessions) else: chat_history = ChatHistory({ "haystack_dates": entry['haystack_dates'], "haystack_session_ids": entry['haystack_session_ids'], "haystack_sessions": entry['haystack_sessions'], "haystack_topics": haystack_topics, }) topic_set = set() for ht in haystack_topics: topic_set.update(ht) if args.user_profile: # user profile temp_qid = qid if '_q_' in qid: temp_qid = qid.split("_q_")[0] user_profile = qid2profiles[temp_qid] agent = RetrievalAgent( chat_history, list(topic_set), user_profile=user_profile, debug=args.debug, vllm=args.vllm, vllm_reading=args.vllm_reading, tritonai=args.tritonai, nvidia=args.nvidia, n_chunks=args.n_chunks, topic_filter=args.topic_filter, no_time_filter=args.no_time_filter, semantic_store=None if args.no_semantic else semantic_store, episodic_store=episodic_store, hier_v2=args.hier_v2, hier_union=args.hier_union, hier_union_flat_k=args.hier_union_flat_k, no_early_answer=args.no_early_answer, ) else: agent = RetrievalAgent( chat_history, list(topic_set), debug=args.debug, vllm=args.vllm, vllm_reading=args.vllm_reading, tritonai=args.tritonai, nvidia=args.nvidia, n_chunks=args.n_chunks, topic_filter=args.topic_filter, no_time_filter=args.no_time_filter, semantic_store=None if args.no_semantic else semantic_store, episodic_store=episodic_store, hier_v2=args.hier_v2, hier_union=args.hier_union, hier_union_flat_k=args.hier_union_flat_k, no_early_answer=args.no_early_answer, ) try: if args.mode == 'embed': final_sess = embedding_search(chat_history, qid, top_k=top_k) attempt_record = [{"plan": {"answer": "none", "reason": "embedding retrieval only"}}] elif args.mode == 'keyword': keywords = generate_keywords(question, deployment_name, api_version, debug=args.debug, vllm=args.vllm, tritonai=args.tritonai, nvidia=args.nvidia) final_sess = keyword_search(chat_history, keywords=keywords) attempt_record = [{"plan": {"answer": "none", "reason": "keyword retrieval only"}}] else: # agent final_sess, attempt_record = agent.run( qid, question, q_date, top_k, model_info, max_loops=args.max_loops, semantic_ret_dict=semantic_ret_dict, haystack_sess_ids=haystack_sess_ids, date_lookup=date_lookup, topic_lookup=topic_lookup, ) if len(attempt_record) == 1 and "answer" in attempt_record[0]["plan"] and not ("none" in attempt_record[0]["plan"]["answer"].lower()): answer = attempt_record[0]["plan"]["answer"] token_budget = agent.get_token_budget() if args.mode == 'agent' else {} wall_time_sec = time.time() - item_start_time print(json.dumps({"q_idx": di, 'question_id': qid, 'question': entry['question'], 'answer': answer, 'n_retrieved': len(final_sess), 'wall_time_sec': round(wall_time_sec, 3)}, indent=4), flush=True) print(json.dumps({"q_idx": di, 'question_id': qid, 'hypothesis': answer, "attempt_record": attempt_record, "token_budget": token_budget, "wall_time_sec": wall_time_sec}), file=out_f, flush=True) else: if len(final_sess) > top_k and retrieved_log_file is not None: final_top_k_sess, _ = filter_out_by_embedding(final_sess, qid=qid, top_k=top_k) retrieved_str = final_top_k_sess.to_prompt(granularity="session", _format="json") else: retrieved_str = final_sess.to_prompt(granularity="session", _format="json") if args.answer_prompt_v2: answer_prompt_template = ( "You are answering a question using a list of chat-session transcripts between the user and an assistant.\n" "\n" "How to answer:\n" "1. Scan ALL retrieved sessions in chronological order. The SESSION DATE on each transcript is when that conversation occurred. The Current Date below is when the question was asked, not when events happened.\n" "2. Identify every session containing a candidate fact. If sessions conflict, prefer the most RECENT session that addresses the same fact (knowledge update).\n" "3. For aggregation questions ('how many', 'list all', 'between X and Y'), enumerate matches across ALL relevant sessions; do not stop at the first.\n" "4. For temporal queries ('last Friday', 'two weeks ago'), resolve the relative date against the SESSION DATE of the session that uses that phrase, not the Current Date.\n" "5. If the retrieved sessions do NOT contain the answer, reply exactly 'Insufficient information in retrieved sessions.' Do not fabricate.\n" "6. Be terse: state the direct answer first, then one short sentence citing the session date(s) you relied on.\n" "\n" "Chat history sessions:\n" "\n" "{}\n" "\n" "Current Date: {}\n" "Question: {}\n" "Answer:" ) else: answer_prompt_template = "I will give you several chat history sessions between you and a user. Please answer the question given the information.\n\n\nChat history sessions:\n\n{}\n\nCurrent Date: {}\nQuestion: {}\nAnswer:" answer_prompt = answer_prompt_template.format(retrieved_str, entry['question_date'], entry['question']) completion = llm_call( deployment_name, api_version, answer_prompt, debug=args.debug, vllm=args.vllm, tritonai=args.tritonai, nvidia=args.nvidia, ) answer = (completion.choices[0].message.content or "").strip() if args.mode == 'agent': agent._track_usage('final_answer', completion) token_budget = agent.get_token_budget() if args.mode == 'agent' else {} retrieval_metric = {} if len(final_sess) > 0 and retrieved_log_file is not None: sess_sorted = embedding_search(final_sess, qid, top_k=20) sess_id_sorted = sess_sorted.get_session_ids() for topk in [5, 10, 20, 30]: recall_any, recall_all = evaluate_retrieval(sess_id_sorted[:topk], entry['answer_session_ids']) retrieval_metric.update({ 'recall_any@{}'.format(topk): recall_any, 'recall_all@{}'.format(topk): recall_all }) retrieval_metric_list.append(retrieval_metric) print_average_metrics(retrieval_metric_list) print(json.dumps({"q_idx": di, 'n_prompt_tok': completion.usage.prompt_tokens, 'n_completion_tok': completion.usage.completion_tokens, 'hypothesis': answer, 'wall_time_sec': round(time.time() - item_start_time, 3)}), flush=True) print(json.dumps({"q_idx": di, 'question_id': qid, 'hypothesis': answer, 'n_prompt_tok': completion.usage.prompt_tokens, 'n_completion_tok': completion.usage.completion_tokens, "attempt_record": attempt_record, "retrieved_sess_ids": final_sess.get_session_ids(), "retrieval_metric": retrieval_metric, "token_budget": token_budget, "wall_time_sec": time.time() - item_start_time}), file=out_f, flush=True) except Exception as e: print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True) continue ############# save cache ########################## if not args.no_save_cache: with open(plan_cache_file, "w") as fw: json.dump(qid2plan, fw, indent=2) with open(veri_reading_log_file, "w") as fw: json.dump(qid2rel_sess_ids, fw, indent=2)