| import argparse |
| import os |
| import re |
| import json |
| import time |
| from json import JSONDecodeError |
| from datetime import datetime, timedelta |
from typing import List, Dict, Any
from functools import lru_cache
|
|
| from openai import OpenAI |
| try: |
| from openai import AzureOpenAI |
| from azure.identity import ChainedTokenCredential, AzureCliCredential, ManagedIdentityCredential, get_bearer_token_provider |
| |
| AZURE_OAUTH_SCOPE = os.environ.get("AZURE_OAUTH_SCOPE", "") |
| if AZURE_OAUTH_SCOPE: |
| credential = get_bearer_token_provider(ChainedTokenCredential( |
| AzureCliCredential(), |
| ManagedIdentityCredential(), |
| ), AZURE_OAUTH_SCOPE) |
| else: |
| credential = None |
| except ImportError: |
| AzureOpenAI = None |
| credential = None |
|
|
| from model_zoo import model_zoo |
| from memory import EpisodicMemoryStore, SemanticMemoryStore |
| try: |
| import tiktoken |
| except ImportError: |
| tiktoken = None |
|
|
| try: |
| from transformers import AutoTokenizer, PreTrainedTokenizerBase |
except ImportError:
    AutoTokenizer = None
    # Empty tuple sentinel: isinstance(x, ()) is always False, so the
    # HF-tokenizer branches below are skipped cleanly when transformers
    # is unavailable.
    PreTrainedTokenizerBase = ()
| from collections import defaultdict |
|
|
def get_hf_tokenizer_for_vllm(model_name: str):
    if AutoTokenizer is None:
        raise ImportError("transformers is required to load a Hugging Face tokenizer for vLLM")
    return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
|
|
|
|
| |
| endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "") |
|
|
| |
| TRITONAI_BASE_URL = os.environ.get("LITELLM_BASE_URL", "") |
|
|
| |
qid2plan = {}
if 'plan_cache' in os.environ:
    plan_cache_file = os.environ['plan_cache']
    if os.path.exists(plan_cache_file):
        with open(plan_cache_file) as f:
            qid2plan = json.load(f)
else:
    plan_cache_file = 'response_cache/qa/evolv_mem_v3_plan_cache_gpt5-1'
|
|
# Module-level default; __main__ re-derives this path with an n_chunks suffix.
veri_reading_log_file = os.environ.get('reading_cache', '')
qid2rel_sess_ids = {}
if veri_reading_log_file and os.path.exists(veri_reading_log_file):
    with open(veri_reading_log_file) as f:
        qid2rel_sess_ids = json.load(f)
|
|
|
|
| |
| |
| |
| |
| |
| |
retrieved_log_file = None
if 'ret_cache' in os.environ:
    retrieved_log_file = os.environ['ret_cache']

retrieved_data_dict = {}
if retrieved_log_file and os.path.exists(retrieved_log_file):
    print("loading existing retrieved results ...")
    with open(retrieved_log_file) as f:
        retrieved_data = [json.loads(line) for line in f]
    retrieved_data_dict = {x['question_id']: x for x in retrieved_data}

with open("dataset/all_sessions.json") as f:
    valid_sess_set = set(json.load(f).keys())
|
|
|
|
| def parse_json(response_content): |
| """Safely parse JSON content from a string response.""" |
| candidates = [] |
|
|
| if '```json' in response_content: |
| start_idx = response_content.find('```json') + 7 |
| end_idx = response_content.rfind('```') |
| if end_idx > start_idx: |
| candidates.append(response_content[start_idx:end_idx].strip()) |
|
|
| |
| brace_start = response_content.find('{') |
| brace_end = response_content.rfind('}') + 1 |
| if brace_start >= 0 and brace_end > brace_start: |
| candidates.append(response_content[brace_start:brace_end].strip()) |
|
|
| for json_block in candidates: |
| try: |
| result = json.loads(json_block) |
| return result |
| except (JSONDecodeError, ValueError): |
| continue |
|
|
| print(f"[Warning] Failed to decode JSON from response (all strategies failed)") |
| print(f"[Debug] Raw response content (truncated): {response_content[:500]}") |
| return {} |
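# Illustrative behavior (inputs hypothetical): the fenced ```json block is tried
# before the widest {...} span fallback.
#   parse_json('Sure! ```json\n{"keywords": ["x"]}\n```')  -> {"keywords": ["x"]}
#   parse_json('prefix {"a": 1} suffix')                   -> {"a": 1}
#   parse_json('no JSON at all')                           -> {}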
|
|
|
|
| def _retryable_status(e): |
| |
| status = getattr(e, "status_code", None) or getattr(e, "http_status", None) |
| if status is not None: |
| return int(status) |
| resp = getattr(e, "response", None) |
| if resp is not None and getattr(resp, "status_code", None) is not None: |
| return int(resp.status_code) |
| |
| msg = str(e).lower() |
| if "429" in msg or "rate limit" in msg: |
| return 429 |
| if "500" in msg or "internal server error" in msg: |
| return 500 |
| if "503" in msg or "API Configuration unavailable" in msg: |
| return 503 |
| return None |
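# Examples (illustrative): structured status attributes win; otherwise substring
# matching on the lowercased message is a best-effort fallback.
#   _retryable_status(Exception("429 rate limit exceeded"))  -> 429
#   _retryable_status(Exception("Internal Server Error"))    -> 500
#   _retryable_status(Exception("connection reset"))         -> None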
|
|
|
|
| MAX_CONTEXT_TOKENS = 272_000 |
| |
|
|
|
|
| class CharacterEncoder: |
| """Conservative fallback when tokenizer packages are unavailable.""" |
|
|
| def encode(self, text, **kwargs): |
| return list(text) |
|
|
| def decode(self, toks): |
| return "".join(toks) |
|
|
|
|
@lru_cache(maxsize=None)
def _get_encoder(model_name: str):
    """
    Return a token encoder for the given model name.
    Cached (lru_cache) so HF tokenizers are not reloaded on every call.
    """
    # Qwen-family models served via vLLM need their matching HF tokenizer.
    if AutoTokenizer is not None and "Qwen" in model_name:
| try: |
| return AutoTokenizer.from_pretrained( |
| "Qwen/Qwen3-30B-A3B-Instruct-2507", |
| trust_remote_code=True, |
| use_fast=False |
| ) |
| except Exception as e: |
| print(f"[WARN] Failed to load Qwen tokenizer: {e}. Falling back to tiktoken.") |
|
|
| |
| if tiktoken is not None: |
| try: |
| return tiktoken.encoding_for_model(model_name) |
| except Exception: |
| |
| return tiktoken.get_encoding("cl100k_base") |
| return CharacterEncoder() |
|
|
|
|
| def _truncate_to_tokens(text, enc, max_tokens) -> str: |
| """ |
| Truncates text to the last `max_tokens`. |
| Compatible with both tiktoken and Hugging Face AutoTokenizers. |
| """ |
| |
| if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase): |
| |
| |
| toks = enc.encode(text, add_special_tokens=False) |
| |
| else: |
| toks = enc.encode(text, disallowed_special=()) |
|
|
| if len(toks) <= max_tokens: |
| return text |
|
|
| |
| toks = toks[-max_tokens:] |
| |
| return enc.decode(toks) |
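# Minimal usage sketch (tiktoken path; model name illustrative): only the LAST
# max_tokens survive, so the most recent context is preserved.
#   enc = _get_encoder("gpt-4o")
#   _truncate_to_tokens("long prompt " * 1000, enc, max_tokens=8)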
|
|
|
|
| def truncate_chat_prompt(tokenizer, messages, max_context, max_output, overhead=256): |
| |
| prompt_text = tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| ) |
| input_ids = tokenizer(prompt_text, add_special_tokens=False).input_ids |
|
|
| budget = max_context - max_output - overhead |
| if budget < 0: |
| raise ValueError("max_output_tokens + overhead exceeds max_context_tokens") |
|
|
| if len(input_ids) > budget: |
| input_ids = input_ids[-budget:] |
| prompt_text = tokenizer.decode(input_ids, skip_special_tokens=False) |
|
|
| return prompt_text |
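# Hypothetical usage with an HF chat tokenizer (names illustrative): the prompt
# is trimmed from the left so max_output_tokens always fits in the context.
#   tok = get_hf_tokenizer_for_vllm("Qwen/Qwen3-30B-A3B-Instruct-2507")
#   text = truncate_chat_prompt(tok, [{"role": "user", "content": long_doc}],
#                               max_context=131_072, max_output=1024)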
| |
|
|
| def llm_call(deployment_name: str, |
| api_version: str, |
| _prompt: str, |
| max_context_tokens: int = MAX_CONTEXT_TOKENS, |
| max_output_tokens: int = 1024, |
| extra_overhead_tokens: int = 32, |
| debug: bool = False, |
| vllm: bool = False, |
| tritonai: bool = False, |
| nvidia: bool = False): |
| if nvidia: |
| client = OpenAI( |
| api_key=os.getenv("NV_API_KEY"), |
| base_url="https://inference-api.nvidia.com/v1", |
| ) |
| elif tritonai: |
| client = OpenAI( |
| api_key=os.getenv("TRITONAI_API_KEY"), |
| base_url=TRITONAI_BASE_URL, |
| ) |
| max_context_tokens = 131_072 |
| |
| if max_output_tokens < 4096: |
| max_output_tokens = 4096 |
| elif vllm: |
| |
| vllm_base_url = os.getenv("VLLM_BASE_URL", "http://localhost:8000/v1") |
| vllm_api_key = os.getenv("VLLM_API_KEY", "EMPTY") |
| client = OpenAI( |
| base_url=vllm_base_url, |
| api_key=vllm_api_key, |
| ) |
| |
| |
| deployment_name = os.getenv("VLLM_MODEL_NAME", deployment_name) |
| |
| max_context_tokens = 131_072 |
| elif debug: |
| client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) |
    else:
        if AzureOpenAI is None:
            raise ImportError("openai (Azure support) and azure-identity are required for Azure calls")
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            azure_ad_token_provider=credential,
            api_version=api_version,
        )
|
|
| enc = _get_encoder(deployment_name) |
|
|
| |
| budget = max_context_tokens - max_output_tokens - extra_overhead_tokens |
| if budget < 0: |
| raise ValueError("max_output_tokens + overhead exceeds max_context_tokens") |
|
|
| prompt_truncated = _truncate_to_tokens(_prompt, enc, budget) |
|
|
| |
    # Some gateways reject raw control characters: force valid UTF-8, then strip
    # C0 controls (except \t, \n, \r) and DEL.
    if nvidia or tritonai:
        prompt_truncated = prompt_truncated.encode('utf-8', errors='replace').decode('utf-8', errors='replace')
        prompt_truncated = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', prompt_truncated)
        # Sanity check: raises early if the sanitized prompt is still not JSON-serializable.
        json.dumps(prompt_truncated)
|
|
| |
| msg_role = "user" if (tritonai or nvidia) else "system" |
| kwargs = { |
| 'model': deployment_name, |
| 'messages':[ |
| {"role": msg_role, "content": prompt_truncated} |
| ] |
| } |
|
|
| while True: |
| try: |
| completion = client.chat.completions.create(**kwargs) |
| break |
| except Exception as e: |
| from openai import APITimeoutError as _APITimeoutError |
| if isinstance(e, _APITimeoutError): |
| print(f"[WARN] APITimeoutError from LLM; sleeping 30s then retrying...", flush=True) |
| time.sleep(30) |
| continue |
| st = _retryable_status(e) |
| |
| retryable = (429, 500, 503, 403) + ((404,) if tritonai else ()) |
| if st in retryable: |
| print( |
| f"[WARN] q_idx={di} HTTP {st} from LLM; sleeping 60s then retrying...", |
| flush=True |
| ) |
| time.sleep(60) |
| continue |
| |
| print('One exception captured', repr(e), flush=True) |
| raise |
|
|
| |
| |
| return completion |
|
|
| def custom_to_iso8601(time_str): |
| """ |
| Convert '2023/04/10 (Mon) 23:07' to '2023-04-10T23:07:00' |
| """ |
| |
| clean = time_str.split('(')[0].strip() + ' ' + time_str.split(')')[-1].strip() |
| |
| dt = datetime.strptime(clean, "%Y/%m/%d %H:%M") |
| |
| return dt.isoformat() |
|
|
|
|
def evaluate_retrieval(recalled_docs, correct_docs):
    """recall_any: at least one gold doc was retrieved; recall_all: every gold doc was."""
    recall_any = float(any(doc in recalled_docs for doc in correct_docs))
    recall_all = float(all(doc in recalled_docs for doc in correct_docs))
    return recall_any, recall_all
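# Semantics sketch (IDs hypothetical):
#   evaluate_retrieval(['s1', 's2'], ['s2', 's9'])  -> (1.0, 0.0)  # some gold found
#   evaluate_retrieval(['s1', 's2'], ['s1', 's2'])  -> (1.0, 1.0)  # all gold found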
|
|
|
|
| def print_average_metrics(retrieval_metric_list): |
| metric_sums = defaultdict(float) |
| metric_counts = defaultdict(int) |
|
|
| for metric in retrieval_metric_list: |
| for k, v in metric.items(): |
| metric_sums[k] += v |
| metric_counts[k] += 1 |
|
|
| print("\t\t Average metrics:") |
| for k in sorted(metric_sums): |
| avg = metric_sums[k] / metric_counts[k] |
| print(f"\t\t {k}: {avg:.4f}") |
|
|
|
|
| |
| prompt_path = "prompts/agentic_retrieval_prompt.txt" |
| with open(prompt_path, "r", encoding="utf-8") as f: |
| stg_prompt = f.read() |
|
|
|
|
| class ChatHistory: |
| def __init__(self, data: Dict[str, Any] = None, sessions: List = None): |
| assert not (data is not None and sessions is not None), "ChatHistory: Only one of data or sessions may be provided." |
| |
| if data is not None: |
| self.raw_data = data |
| self.sessions = [] |
| self._parse_sessions() |
| elif sessions is not None: |
| self.sessions = sessions |
| self.messages = [] |
| for sess in self.sessions: |
| session_id = sess['session_id'] |
| timestamp = sess['timestamp'] |
| for turn_idx, msg in enumerate(sess['session']): |
| entry = { |
| "role": msg.get("role"), |
| "content": msg.get("content"), |
| "session_id": session_id, |
| "turn_index": turn_idx, |
| "timestamp": timestamp, |
| "iso_datetime": timestamp.isoformat(), |
| "has_answer": msg.get("has_answer", False) |
| } |
| self.messages.append(entry) |
| else: |
| self.sessions = [] |
| self.messages = [] |
|
|
    def get_contents(self, granularity='turn', _format='json') -> list:
        if granularity == "turn":
            if _format == "json":
                # default=str: message entries carry datetime timestamps.
                return [json.dumps(msg, default=str) for msg in self.messages]
            else:
                return [msg['content'] for msg in self.messages]
        else:
            if _format == "json":
                return [json.dumps(session, default=str) for session in self.sessions]
            else:
                # Sessions have no top-level role/content; serialize their turns instead.
                return [json.dumps([{"role": m["role"], "content": m["content"]}
                                    for m in session["session"]])
                        for session in self.sessions]
|
|
    def to_prompt(self, granularity='session', _format="json"):
        # Only session-level JSON rendering is implemented; args kept for call-site compatibility.
        history_str = ""
| for session in self.sessions: |
| sess_str = json.dumps([{"role": x["role"], "content": x["content"]} for x in session['session']]) |
| history_str += f"Session Date: {session['session_date']}\nSession Content:\n{sess_str}\n" |
| return history_str |
|
|
| def get_session_ids(self): |
| return [s['session_id'] for s in self.sessions] |
| |
| @staticmethod |
| def _parse_date(date_str: str) -> datetime: |
| |
| |
| date_part, time_part = date_str.split('(')[0].strip(), date_str.split(')')[-1].strip() |
| dt = datetime.strptime(date_part + time_part, "%Y/%m/%d%H:%M") |
| return dt |
| |
| def _parse_sessions(self): |
| """ |
| Flattens sessions into a list of messages, each with ISO 8601 date, session ID, and turn index |
| """ |
| self.messages = [] |
| for date_str, session_id, session, topic in zip( |
| self.raw_data['haystack_dates'], |
| self.raw_data['haystack_session_ids'], |
| self.raw_data['haystack_sessions'], |
| self.raw_data['haystack_topics'] |
| ): |
| timestamp = self._parse_date(date_str) |
| for turn_idx, msg in enumerate(session): |
| entry = { |
| "role": msg.get("role"), |
| "content": msg.get("content"), |
| "session_id": session_id, |
| "turn_index": turn_idx, |
| "timestamp": timestamp, |
| "iso_datetime": timestamp.isoformat(), |
| "session_date": date_str, |
| "has_answer": msg.get("has_answer", False) |
| } |
| self.messages.append(entry) |
| self.sessions.append({ |
| "session_date": date_str, |
| "timestamp": timestamp, |
| "session_id": session_id, |
| "session": session, |
| "topic": topic, |
| }) |
| |
| |
| |
|
|
| def __len__(self): |
| return len(self.sessions) |
| |
    def __getitem__(self, idx) -> Dict[str, Any]:
        return self.sessions[idx]
|
|
| def get_item_by_index(self, idx): |
| if isinstance(idx, range) or isinstance(idx, list): |
| max_idx = len(self.sessions) |
| valid_indices = [i for i in idx if 0 <= i < max_idx] |
| selected_sessions = [self.sessions[i] for i in valid_indices] |
| return ChatHistory(sessions=selected_sessions) |
| else: |
| raise ValueError("Input must be a list or range of indices.") |
|
|
| def get_item_by_session_ids(self, sess_set): |
| if not isinstance(sess_set, set): |
| sess_set = set(sess_set) |
| new_sessions = [] |
| for sess in self.sessions: |
| if sess['session_id'] in sess_set: |
| new_sessions.append(sess) |
|
|
| return ChatHistory(sessions=new_sessions) |
|
|
    def get_item_by_ranked_session(self, sess_id_sorted):
        new_sessions = []
        for sess_id in sess_id_sorted:
            for sess in self.sessions:
                # Exact match, not substring containment, to avoid false positives.
                if sess['session_id'] == sess_id:
                    new_sessions.append(sess)

        return ChatHistory(sessions=new_sessions)
|
|
| def get_item_by_topics(self, topics): |
| new_sessions = [] |
| new_sess_ids = set() |
| for sess in self.sessions: |
| for tp in sess['topic']: |
| if tp in topics and sess['session_id'] not in new_sess_ids: |
| new_sessions.append(sess) |
| new_sess_ids.add(sess['session_id']) |
| break |
|
|
| return ChatHistory(sessions=new_sessions) |
|
|
| def merge_rel_sess(self, new_sessions): |
| |
| all_sessions = {s["session_id"]: s for s in self.sessions} |
|
|
| |
| for s in new_sessions.sessions: |
| if s["session_id"] not in all_sessions: |
| all_sessions[s["session_id"]] = s |
|
|
| |
        merged_raw_data = {
            "haystack_dates": [s["session_date"] for s in all_sessions.values()],
            "haystack_session_ids": [s["session_id"] for s in all_sessions.values()],
            "haystack_sessions": [s["session"] for s in all_sessions.values()],
            "haystack_topics": [s.get("topic", []) for s in all_sessions.values()],
        }
        # Rebuild both views; assigning a ChatHistory to self.sessions would corrupt state.
        merged = ChatHistory(merged_raw_data)
        self.sessions = merged.sessions
        self.messages = merged.messages
|
|
|
|
|
|
def generate_keywords(question: str, deployment_name, api_version, debug=False, vllm=False,
                      tritonai=False, nvidia=False) -> List[str]:
| |
| with open('prompts/keyword_search_prompt.txt') as f: |
| prompt_template = f.read() |
|
|
| prompt = prompt_template + question |
| |
| completion = llm_call( |
| deployment_name, |
| api_version, |
| prompt, |
| debug=debug, |
| vllm=vllm, |
| tritonai=tritonai, |
| nvidia=nvidia, |
| ) |
| |
| response_content = (completion.choices[0].message.content or "").strip() |
| result = parse_json(response_content) |
| keywords = result["keywords"] if "keywords" in result else [] |
| return keywords |
|
|
def keyword_search(chat_history: ChatHistory, keywords: list):
    print(f"\t\t** keyword search **: {keywords}")

    matched_msgs = [
        msg for msg in chat_history.messages
        if any(kw.lower() in (msg.get("content") or "").lower() for kw in keywords)
    ]
| |
| if matched_msgs: |
| new_sess_ids = set() |
| for msg in matched_msgs: |
| key = msg["session_id"] |
| new_sess_ids.add(key) |
| new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids) |
| else: |
| new_chat_history = ChatHistory() |
|
|
| return new_chat_history |
|
|
|
|
| def is_turn_id(text): |
| pattern = r'_\d+$' |
| return bool(re.search(pattern, text)) |
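# Examples: a trailing "_<digits>" marks a turn-level corpus ID.
#   is_turn_id("answer_42_7")  -> True
#   is_turn_id("answer_42x")   -> False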
|
|
|
|
def _normalize_corpus_id(corpus_id: str) -> str:
    """Map a turn-/fact-level retrieval corpus ID back to its parent session ID."""
    if corpus_id in valid_sess_set:
        return corpus_id
    if "_turn" in corpus_id:
        return corpus_id.split("_turn")[0]
    if "_fact" in corpus_id:
        return corpus_id.split("_fact")[0]
    if "noans" in corpus_id:
        return corpus_id.replace("noans", "answer")
    if is_turn_id(corpus_id):
        return "_".join(corpus_id.split("_")[:-1])
    return corpus_id


def embedding_search(chat_history: ChatHistory, qid: str, top_k: int = 50, exclude_sess=None):
    print("\t\t** embedding based retrieval **")
    new_sess_ids = []

    curr_all_sess = set(chat_history.get_session_ids())
    if exclude_sess:
        curr_all_sess = curr_all_sess - set(exclude_sess)

    for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
        sid = _normalize_corpus_id(item["corpus_id"])
        if sid not in valid_sess_set:
            print(item["corpus_id"], sid)
        assert sid in valid_sess_set

        # Deduplicate: several turn-level hits may map to the same session.
        if sid in curr_all_sess and sid not in new_sess_ids:
            new_sess_ids.append(sid)
            if len(new_sess_ids) == top_k:
                break
    new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
    return new_chat_history
|
|
|
|
def filter_out_by_embedding(chat_history: ChatHistory, qid: str, top_k: int = 50):
    print("\t\t** [filter_out] embedding based retrieval - loading existing results ...")
    new_sess_ids = []
    curr_all_sess = set(chat_history.get_session_ids())
    for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
        sid = _normalize_corpus_id(item["corpus_id"])
        if sid not in valid_sess_set:
            print(item["corpus_id"], sid)
        assert sid in valid_sess_set

        # Deduplicate: several turn-level hits may map to the same session.
        if sid in curr_all_sess and sid not in new_sess_ids:
            new_sess_ids.append(sid)
            if len(new_sess_ids) == top_k:
                break
    new_chat_history = chat_history.get_item_by_ranked_session(new_sess_ids)
    return new_chat_history, 0.0
|
|
|
|
| def flat_embedding_top_k_ids(qid: str, haystack_sess_ids: List[str], top_k: int) -> List[str]: |
| """ |
| Pull the top_k session IDs from the global GTE retrieval cache (retrieved_data_dict), |
| constrained to the question's haystack. Mirrors embedding_search() but operates on |
| IDs only (no ChatHistory). Used by hier_union to widen the Stage-2 pool. |
| """ |
    haystack_set = set(haystack_sess_ids)
    ids: List[str] = []
    for item in retrieved_data_dict[qid]["retrieval_results"]["ranked_items"]:
        sid = _normalize_corpus_id(item["corpus_id"])
        if sid in haystack_set and sid not in ids:
            ids.append(sid)
            if len(ids) == top_k:
                break
    return ids
|
|
|
|
| def semantic_embedding_search( |
| qid: str, |
| haystack_sess_ids: List[str], |
| semantic_retrieved_dict: dict, |
| top_k: int = 50, |
| ) -> List[str]: |
| """ |
| Like embedding_search() but reads from the pre-computed semantic-gte retrieval cache. |
| Returns an ordered list of up to top_k session IDs from the haystack. |
| """ |
| print("\t\t** semantic embedding retrieval **") |
| haystack_set = set(haystack_sess_ids) |
| ranked_ids: List[str] = [] |
| for item in semantic_retrieved_dict[qid]["retrieval_results"]["ranked_items"]: |
| sid = item["corpus_id"] |
| if sid in haystack_set and sid not in ranked_ids: |
| ranked_ids.append(sid) |
| if len(ranked_ids) == top_k: |
| break |
| return ranked_ids |
|
|
|
|
def time_filter(chat_history: ChatHistory, start_date: str, end_date: str) -> ChatHistory:
    try:
        start = datetime.fromisoformat(start_date)
        end = datetime.fromisoformat(end_date)
        filtered_msgs = [msg for msg in chat_history.messages
                         if start.date() <= msg["timestamp"].date() <= end.date()]
    except Exception as e:
        print("Converting date error: ", e)
        filtered_msgs = []
|
|
| if filtered_msgs: |
| new_sess_ids = set() |
| for msg in filtered_msgs: |
| key = msg["session_id"] |
| new_sess_ids.add(key) |
| new_chat_history = chat_history.get_item_by_session_ids(new_sess_ids) |
| else: |
| new_chat_history = ChatHistory() |
|
|
| return new_chat_history |
|
|
|
|
| class RetrievalAgent: |
| def __init__( |
| self, |
| history: List[Dict], |
| topics: List[str], |
| user_profile: str = None, |
| debug: bool = False, |
| vllm: bool = False, |
| vllm_reading: bool = False, |
| tritonai: bool = False, |
| nvidia: bool = False, |
| n_chunks: int = 10, |
| topic_filter: bool = True, |
| no_time_filter: bool = False, |
| semantic_store: SemanticMemoryStore = None, |
| episodic_store: EpisodicMemoryStore = None, |
| hier_v2: bool = False, |
| hier_union: bool = False, |
| hier_union_flat_k: int = 20, |
| no_early_answer: bool = False, |
| ): |
| self.chat_history = history |
| self.user_profile = user_profile |
| self.topics = topics |
| self.rel_sess = ChatHistory() |
| self.evidence = [] |
| self.debug = debug |
| self.vllm = vllm |
| self.vllm_reading = vllm_reading |
| self.tritonai = tritonai |
| self.nvidia = nvidia |
| self.no_time_filter = no_time_filter |
| self.n_chunks = n_chunks |
| self.topic_filter = topic_filter |
| self.semantic_store = semantic_store |
| self.episodic_store = episodic_store |
| self.hier_v2 = hier_v2 |
| self.hier_union = hier_union |
| self.hier_union_flat_k = hier_union_flat_k |
| self.no_early_answer = no_early_answer |
| self.token_budget = { |
| 'planning': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, |
| 'verification_reading': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, |
| 'is_answerable': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, |
| 'final_answer': {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0}, |
| } |
|
|
| if self.user_profile: |
| with open('prompts/read_and_extract_prompt.txt') as f: |
| self.read_prompt_template = f.read() |
| else: |
| with open('prompts/agentic_retrieval_prompt_wo_profile.txt') as f: |
| self.read_prompt_template = f.read() |
|
|
| def _track_usage(self, component: str, completion) -> None: |
| """Accumulate prompt/completion token counts for a named component.""" |
| usage = getattr(completion, 'usage', None) |
| if usage is None: |
| return |
| self.token_budget[component]['prompt_tokens'] += getattr(usage, 'prompt_tokens', 0) or 0 |
| self.token_budget[component]['completion_tokens'] += getattr(usage, 'completion_tokens', 0) or 0 |
| self.token_budget[component]['n_calls'] += 1 |
|
|
| def get_token_budget(self) -> dict: |
| """Return token_budget with an added 'total' entry.""" |
| total = {'prompt_tokens': 0, 'completion_tokens': 0, 'n_calls': 0} |
| for v in self.token_budget.values(): |
| total['prompt_tokens'] += v['prompt_tokens'] |
| total['completion_tokens'] += v['completion_tokens'] |
| total['n_calls'] += v['n_calls'] |
| return {**self.token_budget, 'total': total} |
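    # Shape of the returned accounting dict (numbers purely illustrative):
    #   {'planning':             {'prompt_tokens': 1200, 'completion_tokens': 150, 'n_calls': 2},
    #    'verification_reading': {...}, 'is_answerable': {...}, 'final_answer': {...},
    #    'total':                {'prompt_tokens': 4800, 'completion_tokens': 900, 'n_calls': 7}}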
|
|
    def is_answerable(self, question: str, question_date: str, retrieved_sess, evidence, model_info, context_str: str = None) -> dict:
| context = "" |
| for k, v in evidence.items(): |
| for e in v: |
| context += f"{e}\n" |
|
|
| |
| sess_str = context_str if context_str is not None else retrieved_sess.to_prompt() |
|
|
| |
| profile_section = "" |
| if self.user_profile: |
| profile_section = f"\nUser Profile:\n{self.user_profile}\n" |
|
|
| ia_prompt_prefix = f""" |
| You are a decision-making agent tasked with determining when sufficient information has been gathered to answer a user's question. |
| |
| Your Task: |
| Analyze the provided question, current date, available memory context, and available evidence to make a binary decision: Answerable or Not answerable. If the information is not sufficient, explain what specific information is needed to provide hints for the next retrieval stage. |
| |
| Question: {question} |
| Current Date: {question_date} |
| {profile_section} |
| Memory Context: |
| """ |
|
|
| output_str = """ |
| Output (always JSON — choose fields per the rules above) |
| |
| Case 1 — Answerable: |
| { |
| "is_answerable": true, |
| "answer": "<concise answer grounded strictly in the Evidence>" |
| } |
| |
| Case 2 — Not answerable: |
| { |
| "is_answerable": false, |
| "info_needed": ["<specific missing detail 1>", "<specific missing detail 2>"] |
| } |
| """ |
| deployment_name, api_version = model_info |
|
|
| |
| |
| |
| enc = _get_encoder(deployment_name) |
|
|
| |
| if self.vllm or self.tritonai: |
| |
| model_max_ctx = 131_072 |
| else: |
| model_max_ctx = MAX_CONTEXT_TOKENS |
|
|
| max_output_tokens = 1024 |
| extra_overhead_tokens = 32 |
|
|
| |
| budget = model_max_ctx - max_output_tokens - extra_overhead_tokens |
| if budget <= 0: |
| raise ValueError( |
| f"max_output_tokens ({max_output_tokens}) + overhead " |
| f"({extra_overhead_tokens}) exceeds model_max_ctx ({model_max_ctx})" |
| ) |
|
|
| |
        # Encode with the right API: HF tokenizers do not accept tiktoken's
        # disallowed_special keyword.
        def _encode(text):
            if PreTrainedTokenizerBase and isinstance(enc, PreTrainedTokenizerBase):
                return enc.encode(text, add_special_tokens=False)
            return enc.encode(text, disallowed_special=())

        prefix_tokens = _encode(ia_prompt_prefix)
        output_tokens = _encode(output_str)
        sess_tokens = _encode(sess_str)
|
|
| |
| available_for_sess = budget - len(prefix_tokens) - len(output_tokens) |
| |
| if available_for_sess <= 0: |
| |
| truncated_sess_str = "" |
| else: |
| if len(sess_tokens) > available_for_sess: |
| |
| truncated_sess_tokens = sess_tokens[-available_for_sess:] |
| truncated_sess_str = enc.decode(truncated_sess_tokens) |
| else: |
| truncated_sess_str = sess_str |
|
|
| ia_prompt = ia_prompt_prefix + truncated_sess_str + "\n" |
| |
| |
| |
| |
|
|
| completion = llm_call( |
| deployment_name, |
| api_version, |
| ia_prompt + output_str, |
| max_context_tokens=model_max_ctx, |
| max_output_tokens=max_output_tokens, |
| extra_overhead_tokens=extra_overhead_tokens, |
| debug=self.debug, |
| vllm=self.vllm, |
| tritonai=self.tritonai, |
| nvidia=self.nvidia, |
| ) |
| self._track_usage('is_answerable', completion) |
| response_content = (completion.choices[0].message.content or "").strip() |
| print(f"\t\t[Agent] is_answerable: {response_content}") |
| result = parse_json(response_content) |
| if not result: |
| print("[Warning] Empty or invalid JSON in is_answerable() response.") |
| return {"is_answerable": False, "info_needed": ["Parsing failed"]} |
| return result |
| |
    def _read_and_verify(self, question: str, question_date: str, evidence: ChatHistory, n_chunks=10, model_info=None):
| |
| relevant_indices = [] |
| evidence_list = [] |
| max_idx = len(evidence) |
| for j in range(0, len(evidence), n_chunks): |
| chunk_range = range(j, j+n_chunks) |
| valid_indices = [i for i in chunk_range if 0 <= i < max_idx] |
| cur_chunk = evidence.get_item_by_index(valid_indices) |
| cur_chunk_sess = [[{"role": m["role"], "content": m["content"]} for m in sess['session']] |
| for sess in cur_chunk.sessions] |
| cur_chunk_sess_date = [sess['session_date'] for sess in cur_chunk.sessions] |
| sess_input_str = "\n".join([ |
| f"### Session Index: {i}\n### Session Date: {sess_date}\n\n{json.dumps(sess)}\n" |
| for i, (sess, sess_date) in enumerate(zip(cur_chunk_sess, cur_chunk_sess_date)) |
| ]) |
| _prompt = self.read_prompt_template + f"## Question: {question}\n## Question Date: {question_date}\n## Session list:\n\n{sess_input_str}\nNow, identify **only the sessions strictly necessary to answer the question**." |
            # Unpack explicitly rather than relying on globals defined in __main__.
            deployment_name, api_version = model_info
            completion = llm_call(
                deployment_name,
                api_version,
                _prompt,
                debug=self.debug,
                vllm=self.vllm or self.vllm_reading,
                nvidia=self.nvidia,
            )
| self._track_usage('verification_reading', completion) |
| response_content = (completion.choices[0].message.content or "").strip() |
|
|
| print(f"\t\t {valid_indices[0]}~{valid_indices[-1]}: response: {response_content.replace(chr(10), '')}") |
| try: |
| start_idx = response_content.rfind('{') |
| end_idx = response_content.rfind('}') + 1 |
| json_block = response_content[start_idx:end_idx] |
| result = json.loads(json_block) |
|
|
| if "index" in result and result['index'] and 'evidence' and result: |
| relevant_indices.extend([j + idx for idx in result['index']]) |
| evidence_list.extend(result['evidence']) |
| except Exception as e: |
| print(f"Error parsing LLM response: {e}") |
| |
| if relevant_indices: |
| return evidence.get_item_by_index(relevant_indices), evidence_list |
| else: |
| return ChatHistory(), [] |
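    # Expected per-chunk LLM output (sketch inferred from the parsing above; the
    # indices are chunk-local and re-based by adding the chunk offset j):
    #   {"index": [0, 3], "evidence": ["user adopted a dog in May 2023", "..."]}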
|
|
| def _read_and_verify_with_cache(self, qid:str, pool): |
| relevant_sess_ids = [] |
|
|
| for sess in pool.sessions: |
| sess_id = sess['session_id'] |
| if sess_id in qid2rel_sess_ids[qid]: |
| relevant_sess_ids.append(sess_id) |
|
|
| if len(relevant_sess_ids) > 0: |
| return pool.get_item_by_session_ids(relevant_sess_ids), [] |
| else: |
| return ChatHistory(), [] |
| |
| def _plan(self, query: str, query_date: str, attempt_record: list, model_info) -> str: |
| if self.user_profile: |
| template = """ |
| ### User profile: {user_profile} |
| ### Chat history topics: {topics} |
| ### User query: {query} |
| ### User query date: {query_date} |
| ### Previous attempts: |
| {strategies_info} |
| """ |
| else: |
| template = """ |
| ### Chat history topics: {topics} |
| ### User query: {query} |
| ### User query date: {query_date} |
| ### Previous attempts: |
| {strategies_info} |
| """ |
| |
| if attempt_record: |
| strategies_info = "" |
| for loop_num, entry in enumerate(attempt_record): |
| strategies_info += f"\nloop_iteration: {loop_num+1}\n" |
| strategies_info += "\n".join(entry.get('step_logs', [])) |
| if 'n_retrieved_sess' in entry and 'evidence' in entry: |
| strategies_info += f"Retrieved {entry['n_retrieved_sess']} docs, observed_evidence: {entry['evidence']}" |
                if entry['n_retrieved_sess'] == 0:
                    strategies_info += "Additional Instruction: Re-try without filter methods if the previous plan includes topic or time filtering\n"
| else: |
| strategies_info = "(No previous attempt exists)" |
| |
| if self.user_profile: |
| prompt_filled = stg_prompt + template.format( |
| user_profile=self.user_profile, |
| topics=",".join(self.topics), |
| query=query, |
| query_date=query_date, |
| strategies_info=strategies_info) |
| else: |
| prompt_filled = stg_prompt + template.format( |
| topics=",".join(self.topics), |
| query=query, |
| query_date=query_date, |
| strategies_info=strategies_info) |
|
|
| deployment_name, api_version = model_info |
| completion = llm_call( |
| deployment_name, |
| api_version, |
| prompt_filled, |
| debug=self.debug, |
| vllm=self.vllm, |
| tritonai=self.tritonai, |
| nvidia=self.nvidia, |
| ) |
| self._track_usage('planning', completion) |
| response_content = (completion.choices[0].message.content or "").strip() |
| _plan = parse_json(response_content) |
| if not _plan: |
| print("[Warning] Failed to parse plan JSON — retrying once.") |
| completion = llm_call( |
| deployment_name, |
| api_version, |
| prompt_filled, |
| debug=self.debug, |
| vllm=self.vllm, |
| tritonai=self.tritonai, |
| nvidia=self.nvidia, |
| ) |
| self._track_usage('planning', completion) |
| response_content = (completion.choices[0].message.content or "").strip() |
| _plan = parse_json(response_content) |
| if not _plan: |
| print("[Warning] Failed to parse plan JSON after retry — returning fallback plan.") |
| _plan = {"answer": "none", "reason": "invalid JSON response", "topics": [], "strategy": []} |
| return _plan |
|
|
| def _run_stage1( |
| self, |
| qid: str, |
| question: str, |
| question_date: str, |
| top_k: int, |
| model_info, |
| haystack_sess_ids: List[str], |
| date_lookup: Dict[str, str], |
| semantic_ret_dict: dict, |
| ) -> dict: |
| """ |
| Stage 1: retrieve and evaluate using semantic memory only (summaries + facts). |
| |
| Returns a dict with: |
| is_answerable : bool |
| answer : str | None (set when is_answerable is True) |
| candidate_ids : list[str] (top-K session IDs from semantic retrieval) |
| attempt_record : list |
| """ |
| print(f"\t[Stage 1] Semantic memory retrieval for qid={qid}") |
|
|
| |
| candidate_ids = semantic_embedding_search( |
| qid, haystack_sess_ids, semantic_ret_dict, top_k=top_k |
| ) |
|
|
| |
| if self.hier_v2: |
| print(f"\t[Stage 1 hier_v2] embedding-only candidates: {len(candidate_ids)}") |
| return { |
| "is_answerable": False, |
| "answer": None, |
| "candidate_ids": candidate_ids[:top_k], |
| "attempt_record": [{ |
| "stage": "semantic_v2", |
| "plan": {}, |
| "n_candidates": len(candidate_ids), |
| "candidate_ids": candidate_ids[:top_k], |
| }], |
| } |
|
|
| |
| plan = self._plan(question, question_date, [], model_info) |
| print(json.dumps(plan, indent=4), flush=True) |
|
|
| |
| if "answer" in plan and plan["answer"].lower() != "none": |
| return { |
| "is_answerable": True, |
| "answer": plan["answer"], |
| "candidate_ids": candidate_ids, |
| "attempt_record": [{"plan": plan, "stage": "semantic"}], |
| } |
|
|
| |
| keyword_ids: List[str] = [] |
| for step in plan.get("strategy", []): |
| if step.get("method") == "keyword": |
| kws = step.get("keywords", []) |
| matched = self.semantic_store.keyword_search(kws, haystack_sess_ids) |
| print(f"\t\t** semantic keyword search **: {kws} -> {len(matched)} matches") |
| keyword_ids.extend(sid for sid in matched if sid not in keyword_ids) |
|
|
| |
| for step in plan.get("strategy", []): |
| if self.no_time_filter: |
| break |
| if step.get("method") == "time_filter": |
| if "time_range" not in step or len(step["time_range"]) != 2: |
| continue |
| start_str, end_str = step["time_range"] |
| try: |
| start_dt = datetime.fromisoformat(start_str) |
| end_dt = datetime.fromisoformat(end_str) |
| candidate_ids = [ |
| sid for sid in candidate_ids |
| if sid in date_lookup and |
| start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date() |
| ] |
| keyword_ids = [ |
| sid for sid in keyword_ids |
| if sid in date_lookup and |
| start_dt.date() <= EpisodicMemoryStore._parse_date(date_lookup[sid]).date() <= end_dt.date() |
| ] |
| print(f"\t\t** semantic time_filter **: {start_str}..{end_str} -> " |
| f"{len(candidate_ids)} embed, {len(keyword_ids)} keyword") |
| except Exception as e: |
| print(f"\t\t[WARN] time_filter parse error: {e}") |
|
|
| |
| all_candidate_ids: List[str] = list(candidate_ids) |
| for sid in keyword_ids: |
| if sid not in all_candidate_ids: |
| all_candidate_ids.append(sid) |
|
|
| |
| all_candidate_ids = all_candidate_ids[:top_k] |
|
|
| if not all_candidate_ids: |
| print("\t[Stage 1] No candidates found in semantic memory.") |
| return { |
| "is_answerable": False, |
| "answer": None, |
| "candidate_ids": [], |
| "attempt_record": [{"plan": plan, "stage": "semantic", "n_candidates": 0}], |
| } |
|
|
| |
| semantic_context_str = self.semantic_store.to_prompt(all_candidate_ids, date_lookup) |
| print(f"\t[Stage 1] Built semantic context for {len(all_candidate_ids)} sessions " |
| f"({len(semantic_context_str)} chars)") |
|
|
| |
| accumulated_evidence = {"profile": [], "chat_clues": []} |
| answerable_response = self.is_answerable( |
| question, question_date, |
| retrieved_sess=None, |
| evidence=accumulated_evidence, |
| model_info=model_info, |
| context_str=semantic_context_str, |
| ) |
| print(f"\t[Stage 1] is_answerable: {answerable_response}") |
|
|
| attempt_record = [{ |
| "stage": "semantic", |
| "plan": plan, |
| "n_candidates": len(all_candidate_ids), |
| "candidate_ids": all_candidate_ids, |
| "is_answerable": answerable_response.get("is_answerable", False), |
| }] |
|
|
| if answerable_response.get("is_answerable"): |
| return { |
| "is_answerable": True, |
| "answer": answerable_response.get("answer"), |
| "candidate_ids": all_candidate_ids, |
| "attempt_record": attempt_record, |
| } |
|
|
| print(f"\t[Stage 1] Not answerable from semantic memory. " |
| f"Info needed: {answerable_response.get('info_needed', [])}") |
| return { |
| "is_answerable": False, |
| "answer": None, |
| "candidate_ids": all_candidate_ids, |
| "attempt_record": attempt_record, |
| } |
|
|
| def run(self, qid:str, question: str, question_date: str, top_k: int, model_info, max_loops=3, |
| semantic_ret_dict: dict = None, haystack_sess_ids: List[str] = None, |
| date_lookup: Dict[str, str] = None, topic_lookup: Dict[str, List[str]] = None): |
| accumulated_evidence = {"profile": [], "chat_clues": []} |
| attempt_record = [] |
| loop_num = 0 |
|
|
| |
| |
| |
| stage1_candidate_ids: List[str] = [] |
| if (self.semantic_store is not None |
| and semantic_ret_dict is not None |
| and haystack_sess_ids is not None): |
| stage1_result = self._run_stage1( |
| qid, question, question_date, top_k, model_info, |
| haystack_sess_ids, date_lookup or {}, semantic_ret_dict, |
| ) |
| attempt_record.extend(stage1_result["attempt_record"]) |
| stage1_candidate_ids = stage1_result["candidate_ids"] |
|
|
| if stage1_result["is_answerable"] and not self.no_early_answer: |
| print(f"\t[Stage 1] Answered from semantic memory.") |
| |
| if "answer" in stage1_result and stage1_result["answer"]: |
| attempt_record[0]["plan"] = { |
| **attempt_record[0].get("plan", {}), |
| "answer": stage1_result["answer"], |
| } |
| return ChatHistory(), attempt_record |
| if stage1_result["is_answerable"] and self.no_early_answer: |
| print(f"\t[Stage 1] is_answerable=True but --no_early_answer set; proceeding to Stage-2.") |
|
|
| |
| |
| |
        if self.hier_union and qid in retrieved_data_dict and haystack_sess_ids:
| flat_ids = flat_embedding_top_k_ids(qid, haystack_sess_ids, self.hier_union_flat_k) |
| before = len(stage1_candidate_ids) |
| for sid in flat_ids: |
| if sid not in stage1_candidate_ids: |
| stage1_candidate_ids.append(sid) |
| print(f"\t[hier_union] semantic_top_k={before} + flat_top_{self.hier_union_flat_k}={len(flat_ids)} -> union={len(stage1_candidate_ids)}") |
| attempt_record.append({ |
| "stage": "hier_union", |
| "plan": {}, |
| "n_semantic": before, |
| "n_flat": len(flat_ids), |
| "n_union": len(stage1_candidate_ids), |
| }) |
|
|
| |
| if self.episodic_store is not None and stage1_candidate_ids: |
| print(f"\t[Stage 2] Loading episodic memory for " |
| f"{len(stage1_candidate_ids)} candidate sessions.") |
| raw_sessions = self.episodic_store.get_raw_sessions( |
| stage1_candidate_ids, date_lookup or {}, topic_lookup |
| ) |
| self.chat_history = ChatHistory(sessions=raw_sessions) |
| print(f"\t[Stage 2] Loaded {len(self.chat_history)} sessions into episodic pool.") |
|
|
| |
| |
| |
| |
| if self.hier_v2: |
| print(f"\t[hier_v2] bypassing agent loop; returning {len(self.chat_history)} candidate sessions as retrieved") |
| return self.chat_history, attempt_record |
|
|
| pool = self.chat_history |
| retrieved = ChatHistory() |
|
|
| while loop_num < max_loops: |
| loop_num += 1 |
| if loop_num == 1 and qid in qid2plan: |
| plan = qid2plan[qid] |
| else: |
| plan = self._plan(question, question_date, attempt_record, model_info) |
| qid2plan[qid] = plan |
| print(json.dumps(plan, indent=4), flush=True) |
|
|
| if "answer" in plan and not ("none" in plan["answer"].lower()): |
| print(f"{qid}\t{question}\t{plan['answer']}", flush=True) |
| return ChatHistory(), [{"plan": plan}] |
|
|
| |
| try: |
                candidates, step_logs = self._execute_strategy(pool, plan, question, qid)
| except Exception as e: |
| print(f"[Error] Failed during _execute_strategy: {e}") |
| candidates, step_logs = ChatHistory(), [f"Execution failed: {e}"] |
| |
| if len(candidates) == 0: |
| attempt_record.append({ |
| "loop_iteration": loop_num, |
| "plan": plan, |
| "evidence": accumulated_evidence, |
| "n_candidates_sess": len(candidates), |
| "n_verified_sess": 0, |
| "n_pool": len(pool), |
| "step_logs": step_logs, |
| }) |
| continue |
| else: |
| remaining = set(pool.get_session_ids()) - set(candidates.get_session_ids()) |
| pool = self.chat_history.get_item_by_session_ids(remaining) |
| |
| |
| if qid in qid2rel_sess_ids: |
| verified, evidence_list = self._read_and_verify_with_cache(qid, candidates) |
| else: |
                verified, evidence_list = self._read_and_verify(question, question_date, candidates, n_chunks=self.n_chunks, model_info=model_info)
| qid2rel_sess_ids[qid] = verified.get_session_ids() |
|
|
| if len(verified) == 0: |
| attempt_record.append({ |
| "loop_iteration": loop_num, |
| "plan": plan, |
| "evidence": accumulated_evidence, |
| "n_candidates_sess": len(candidates), |
| "candidates_sess_ids": candidates.get_session_ids(), |
| "n_verified_sess": len(verified), |
| "verified_sess_ids": [], |
| "n_pool": len(pool), |
| "step_logs": step_logs, |
| }) |
| continue |
| else: |
| retrieved.merge_rel_sess(verified) |
|
|
| for ev in evidence_list: |
| if ev not in accumulated_evidence['chat_clues']: |
| accumulated_evidence['chat_clues'].append(ev) |
|
|
| attempt_record.append({ |
| "loop_iteration": loop_num, |
| "plan": plan, |
| "evidence": accumulated_evidence, |
| "n_candidates_sess": len(candidates), |
| "candidates_sess_ids": candidates.get_session_ids(), |
| "n_verified_sess": len(verified), |
| "verified_sess_ids": verified.get_session_ids(), |
| "n_retrieved_sess": len(retrieved), |
| "retrieved_sess_ids": retrieved.get_session_ids(), |
| "n_pool": len(pool), |
| "step_logs": step_logs, |
| }) |
|
|
| |
| if len(retrieved) > top_k: |
| retrieved, _ = filter_out_by_embedding(retrieved, qid=qid, top_k=top_k) |
|
|
| answerable_response = self.is_answerable(question, question_date, retrieved, accumulated_evidence, model_info) |
| if answerable_response["is_answerable"]: |
| plan["answer"] = answerable_response["answer"] |
| return retrieved, attempt_record |
|
|
| if len(pool) == 0: |
| break |
|
|
| return retrieved, attempt_record |
|
|
    def _execute_strategy(self, pool, plan, question, qid):
| step_logs: List[str] = [] |
|
|
| |
| if self.topic_filter and len(plan.get('topics', [])) > 0: |
| pool = pool.get_item_by_topics(plan['topics']) |
| |
| retrieved = ChatHistory() |
|
|
| strategy = plan["strategy"] |
| for step in strategy: |
| method = step.get("method") |
| if method == "keyword": |
| kws = step.get("keywords", []) |
| matched = keyword_search(pool, kws) |
| step_logs.append(f"Method: keyword - matched {len(matched)}/{len(pool)} using {kws}") |
| if len(matched) > 0: |
| retrieved.merge_rel_sess(matched) |
| elif method == "embedding": |
| top_k = 50 |
| matched = embedding_search(pool, qid, top_k=top_k) |
| step_logs.append(f"Method: embedding - top_k={top_k}, matched {len(matched)}/{len(pool)}") |
| if len(matched) > 0: |
| retrieved.merge_rel_sess(matched) |
| elif method == "time_filter": |
| if self.no_time_filter: |
| step_logs.append(f"Method: time_filter - skipped (--no_time_filter)") |
| continue |
| if 'time_range' not in step or len(step['time_range']) != 2: |
| continue |
| if len(retrieved) > 0: |
| retrieved = time_filter(retrieved, start_date=step['time_range'][0], end_date=step['time_range'][1]) |
| else: |
| retrieved = time_filter(pool, start_date=step['time_range'][0], end_date=step['time_range'][1]) |
| step_logs.append(f"Method: time_filter - kept {len(retrieved)}/{len(pool)} in {step['time_range'][0]}..{step['time_range'][1]}") |
| |
| |
| else: |
| step_logs.append(f"unknown method: {method}") |
|
|
        if len(retrieved) > 100:
            top_k = 100
            n_before = len(retrieved)
            retrieved = embedding_search(retrieved, qid, top_k=top_k)
            step_logs.append(
                f"too many sess ({n_before}) - embedding top_k={top_k} kept {len(retrieved)}/{n_before}"
            )
|
|
| return retrieved, step_logs |
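    # Plan shape consumed by _execute_strategy (inferred from usage; values illustrative):
    #   {"answer": "none",
    #    "topics": ["travel"],
    #    "strategy": [{"method": "keyword", "keywords": ["Kyoto", "ryokan"]},
    #                 {"method": "embedding"},
    #                 {"method": "time_filter", "time_range": ["2023-04-01", "2023-04-30"]}]}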
|
|
| def merge_rel_sess(self, new_sessions: ChatHistory): |
| |
| all_sessions = {s["session_id"]: s for s in self.rel_sess.sessions} |
|
|
| |
| for s in new_sessions.sessions: |
| if s["session_id"] not in all_sessions: |
| all_sessions[s["session_id"]] = s |
| |
| |
| |
| |
|
|
| |
        merged_raw_data = {
            "haystack_dates": [s["session_date"] for s in all_sessions.values()],
            "haystack_session_ids": [s["session_id"] for s in all_sessions.values()],
            "haystack_sessions": [s["session"] for s in all_sessions.values()],
            # Required by ChatHistory._parse_sessions, which zips over topics too.
            "haystack_topics": [s.get("topic", []) for s in all_sessions.values()],
        }
        self.rel_sess = ChatHistory(merged_raw_data)
|
|
|
|
| def merge_evidence(self, new_evidence: list): |
| self.evidence = self.evidence + new_evidence |
| print(f"\t\t Updated evidence: {self.evidence}") |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--in_file', type=str, required=True) |
| parser.add_argument('--out_file', type=str, required=True) |
| parser.add_argument('--model_name', type=str, required=True) |
| parser.add_argument('--top_k', type=int, required=True) |
| parser.add_argument('--debug', action='store_true', default=False) |
| parser.add_argument('--vllm', action='store_true', default=False) |
| parser.add_argument('--tritonai', action='store_true', default=False, |
| help='Use OpenAI-compatible LiteLLM proxy for non-reading LLM calls (set TRITONAI_API_KEY)') |
| parser.add_argument('--nvidia', action='store_true', default=False, |
| help='Use NVIDIA inference API (set NV_API_KEY)') |
| parser.add_argument('--vllm_reading', action='store_true', default=False, |
| help='Use vLLM only for verification reading; all other LLM calls use the proprietary API') |
| parser.add_argument('--n_chunks', type=int, default=10, |
| help='Number of sessions per verification-reading LLM call (default 10)') |
| parser.add_argument('--max_loops', type=int, default=3, |
| help='Maximum retrieval/planning loops for agent mode (default 3)') |
| parser.add_argument('--mode', type=str, default="agent", choices=['agent', 'embed', 'keyword']) |
    parser.add_argument('--topic_filter', action=argparse.BooleanOptionalAction, default=True,
                        help='Apply topic filtering in strategy execution (use --no-topic_filter to disable; argparse type=bool would treat any string as True)')
| parser.add_argument('--user_profile', action=argparse.BooleanOptionalAction, default=True, |
| help='Include user profile in prompts (default: True, use --no-user_profile to disable)') |
| parser.add_argument('--no_semantic', action='store_true', default=False, |
| help='Skip semantic Stage 1; run episodic Stage 2 only on all haystack sessions') |
| parser.add_argument('--no_time_filter', action='store_true', default=False, |
| help='Disable time_filter steps in strategy execution (can reuse plan cache)') |
| |
| parser.add_argument('--semantic_ret_cache', type=str, default=None, |
| help='Path to semantic-gte retrieval log (JSONL) for Stage 1') |
| parser.add_argument('--summary_file', type=str, default=None, |
| help='Path to all_session_summary.json for SemanticMemoryStore') |
| parser.add_argument('--facts_file', type=str, default=None, |
| help='Path to all_session_user_facts.json for SemanticMemoryStore') |
| parser.add_argument('--all_sessions_file', type=str, default=None, |
| help='Path to all_sessions.json for lazy episodic loading') |
| parser.add_argument('--no_save_cache', action='store_true', default=False, |
| help='Disable saving plan/reading caches to disk after the run') |
| parser.add_argument('--hier_v2', action='store_true', default=False, |
| help='Stage 1 produces candidates only: skip early-answer return, semantic keyword expansion, time_filter, and is_answerable shortcut') |
| parser.add_argument('--hier_union', action='store_true', default=False, |
| help='hier mode: union Stage-1 semantic candidates with flat-embedding top-K and run agent loop on the merged pool') |
| parser.add_argument('--hier_union_flat_k', type=int, default=20, |
| help='How many flat-embedding top-K IDs to union into the Stage-2 pool (default 20)') |
| parser.add_argument('--no_early_answer', action='store_true', default=False, |
| help='Disable Stage-1 is_answerable early-return shortcut; always proceed to Stage-2 agent loop') |
| parser.add_argument('--answer_prompt_v2', action='store_true', default=False, |
| help='Use the v2 answer prompt with explicit guidance for aggregation, temporal reasoning, knowledge updates, and absence cases.') |
| args = parser.parse_args() |
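    # Example invocation (script name, paths, and model name are illustrative):
    #   ret_cache=caches/ret.jsonl reading_cache=caches/read.json \
    #   python agentic_retrieval.py --in_file dataset/longmemeval_s.json \
    #       --out_file results/agent.jsonl --model_name gpt-4o --top_k 10 --mode agent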
|
|
| |
| veri_reading_log_file = os.environ['reading_cache'] + f'_nchunks{args.n_chunks}' |
| qid2rel_sess_ids = {} |
| if os.path.exists(veri_reading_log_file): |
| qid2rel_sess_ids = json.load(open(veri_reading_log_file)) |
| print(f'Reading cache: {veri_reading_log_file} ({len(qid2rel_sess_ids)} cached entries)') |
|
|
| in_data = json.load(open(args.in_file)) |
| top_k = args.top_k |
| out_file = args.out_file |
|
|
| model_info = model_zoo[args.model_name] |
| deployment_name, api_version = model_info |
|
|
| existings = set() |
| retrieval_metric_list = [] |
| if os.path.exists(out_file): |
| for line in open(out_file): |
| obj = json.loads(line) |
| existings.add(obj['question_id']) |
| if 'retrieval_metric' in obj: |
| retrieval_metric_list.append(obj['retrieval_metric']) |
|
|
| out_f = open(out_file, 'a') |
|
|
| |
| qid2profiles = {} |
| with open("metadata/generated_user_profile.json") as f: |
| qid2profiles = json.load(f) |
| sess2topic = {} |
| with open("metadata/sessions_with_topic.json") as f: |
| sess2topic = json.load(f) |
|
|
| |
| |
| |
| semantic_store = None |
| episodic_store = None |
| semantic_ret_dict = None |
|
|
| if args.summary_file and args.facts_file: |
| semantic_store = SemanticMemoryStore(args.summary_file, args.facts_file) |
|
|
| if args.all_sessions_file: |
| episodic_store = EpisodicMemoryStore(args.all_sessions_file) |
|
|
| if args.semantic_ret_cache: |
| print(f"Loading semantic retrieval cache from {args.semantic_ret_cache} ...") |
| sem_ret_data = [json.loads(line) for line in open(args.semantic_ret_cache)] |
| semantic_ret_dict = {x['question_id']: x for x in sem_ret_data} |
| print(f" Loaded {len(semantic_ret_dict)} entries.") |
|
|
| for di, entry in enumerate(in_data): |
| item_start_time = time.time() |
        qid, question, q_date = entry['question_id'], entry['question'], entry['question_date']
|
|
| if qid in existings: |
| continue |
|
|
| haystack_sess_ids = entry['haystack_session_ids'] |
| haystack_topics = [sess2topic.get(sid, {}).get('category', []) for sid in haystack_sess_ids] |
| date_lookup = dict(zip(haystack_sess_ids, entry['haystack_dates'])) |
| topic_lookup = dict(zip(haystack_sess_ids, haystack_topics)) |
|
|
| |
| if episodic_store is not None: |
| |
| |
| raw_sessions = episodic_store.get_raw_sessions( |
| haystack_sess_ids, date_lookup, topic_lookup |
| ) |
| chat_history = ChatHistory(sessions=raw_sessions) |
| else: |
| chat_history = ChatHistory({ |
| "haystack_dates": entry['haystack_dates'], |
| "haystack_session_ids": entry['haystack_session_ids'], |
| "haystack_sessions": entry['haystack_sessions'], |
| "haystack_topics": haystack_topics, |
| }) |
| |
| topic_set = set() |
| for ht in haystack_topics: |
| topic_set.update(ht) |
|
|
        agent_kwargs = dict(
            debug=args.debug,
            vllm=args.vllm,
            vllm_reading=args.vllm_reading,
            tritonai=args.tritonai,
            nvidia=args.nvidia,
            n_chunks=args.n_chunks,
            topic_filter=args.topic_filter,
            no_time_filter=args.no_time_filter,
            semantic_store=None if args.no_semantic else semantic_store,
            episodic_store=episodic_store,
            hier_v2=args.hier_v2,
            hier_union=args.hier_union,
            hier_union_flat_k=args.hier_union_flat_k,
            no_early_answer=args.no_early_answer,
        )
        if args.user_profile:
            # Sub-question IDs look like "<parent>_q_<n>"; profiles are keyed by parent.
            temp_qid = qid.split("_q_")[0] if '_q_' in qid else qid
            agent_kwargs['user_profile'] = qid2profiles[temp_qid]
        agent = RetrievalAgent(chat_history, list(topic_set), **agent_kwargs)
| |
| try: |
| if args.mode == 'embed': |
| final_sess = embedding_search(chat_history, qid, top_k=top_k) |
| attempt_record = [{"plan": {"answer": "none", "reason": "embedding retrieval only"}}] |
| elif args.mode == 'keyword': |
| keywords = generate_keywords(question, deployment_name, api_version, |
| debug=args.debug, vllm=args.vllm, |
| tritonai=args.tritonai, nvidia=args.nvidia) |
| final_sess = keyword_search(chat_history, keywords=keywords) |
| attempt_record = [{"plan": {"answer": "none", "reason": "keyword retrieval only"}}] |
| else: |
| final_sess, attempt_record = agent.run( |
| qid, question, q_date, top_k, model_info, |
| max_loops=args.max_loops, |
| semantic_ret_dict=semantic_ret_dict, |
| haystack_sess_ids=haystack_sess_ids, |
| date_lookup=date_lookup, |
| topic_lookup=topic_lookup, |
| ) |
|
|
| if len(attempt_record) == 1 and "answer" in attempt_record[0]["plan"] and not ("none" in attempt_record[0]["plan"]["answer"].lower()): |
| answer = attempt_record[0]["plan"]["answer"] |
| token_budget = agent.get_token_budget() if args.mode == 'agent' else {} |
| wall_time_sec = time.time() - item_start_time |
|
|
| print(json.dumps({"q_idx": di, 'question_id': qid, 'question': entry['question'], |
| 'answer': answer, 'n_retrieved': len(final_sess), |
| 'wall_time_sec': round(wall_time_sec, 3)}, indent=4), flush=True) |
| print(json.dumps({"q_idx": di, 'question_id': qid, |
| 'hypothesis': answer, |
| "attempt_record": attempt_record, |
| "token_budget": token_budget, |
| "wall_time_sec": wall_time_sec}), file=out_f, flush=True) |
| else: |
| if len(final_sess) > top_k and retrieved_log_file is not None: |
| final_top_k_sess, _ = filter_out_by_embedding(final_sess, qid=qid, top_k=top_k) |
| retrieved_str = final_top_k_sess.to_prompt(granularity="session", _format="json") |
| else: |
| retrieved_str = final_sess.to_prompt(granularity="session", _format="json") |
|
|
| if args.answer_prompt_v2: |
| answer_prompt_template = ( |
| "You are answering a question using a list of chat-session transcripts between the user and an assistant.\n" |
| "\n" |
| "How to answer:\n" |
| "1. Scan ALL retrieved sessions in chronological order. The SESSION DATE on each transcript is when that conversation occurred. The Current Date below is when the question was asked, not when events happened.\n" |
| "2. Identify every session containing a candidate fact. If sessions conflict, prefer the most RECENT session that addresses the same fact (knowledge update).\n" |
| "3. For aggregation questions ('how many', 'list all', 'between X and Y'), enumerate matches across ALL relevant sessions; do not stop at the first.\n" |
| "4. For temporal queries ('last Friday', 'two weeks ago'), resolve the relative date against the SESSION DATE of the session that uses that phrase, not the Current Date.\n" |
| "5. If the retrieved sessions do NOT contain the answer, reply exactly 'Insufficient information in retrieved sessions.' Do not fabricate.\n" |
| "6. Be terse: state the direct answer first, then one short sentence citing the session date(s) you relied on.\n" |
| "\n" |
| "Chat history sessions:\n" |
| "\n" |
| "{}\n" |
| "\n" |
| "Current Date: {}\n" |
| "Question: {}\n" |
| "Answer:" |
| ) |
| else: |
| answer_prompt_template = "I will give you several chat history sessions between you and a user. Please answer the question given the information.\n\n\nChat history sessions:\n\n{}\n\nCurrent Date: {}\nQuestion: {}\nAnswer:" |
| answer_prompt = answer_prompt_template.format(retrieved_str, entry['question_date'], entry['question']) |
|
|
| completion = llm_call( |
| deployment_name, |
| api_version, |
| answer_prompt, |
| debug=args.debug, |
| vllm=args.vllm, |
| tritonai=args.tritonai, |
| nvidia=args.nvidia, |
| ) |
| answer = (completion.choices[0].message.content or "").strip() |
|
|
| if args.mode == 'agent': |
| agent._track_usage('final_answer', completion) |
| token_budget = agent.get_token_budget() if args.mode == 'agent' else {} |
|
|
| retrieval_metric = {} |
| if len(final_sess) > 0 and retrieved_log_file is not None: |
| sess_sorted = embedding_search(final_sess, qid, top_k=20) |
| sess_id_sorted = sess_sorted.get_session_ids() |
|
|
| for topk in [5, 10, 20, 30]: |
| recall_any, recall_all = evaluate_retrieval(sess_id_sorted[:topk], entry['answer_session_ids']) |
| retrieval_metric.update({ |
| 'recall_any@{}'.format(topk): recall_any, |
| 'recall_all@{}'.format(topk): recall_all |
| }) |
| retrieval_metric_list.append(retrieval_metric) |
| print_average_metrics(retrieval_metric_list) |
|
|
| print(json.dumps({"q_idx": di, 'n_prompt_tok': completion.usage.prompt_tokens, |
| 'n_completion_tok': completion.usage.completion_tokens, |
| 'hypothesis': answer, |
| 'wall_time_sec': round(time.time() - item_start_time, 3)}), flush=True) |
| print(json.dumps({"q_idx": di, 'question_id': qid, |
| 'hypothesis': answer, |
| 'n_prompt_tok': completion.usage.prompt_tokens, |
| 'n_completion_tok': completion.usage.completion_tokens, |
| "attempt_record": attempt_record, |
| "retrieved_sess_ids": final_sess.get_session_ids(), |
| "retrieval_metric": retrieval_metric, |
| "token_budget": token_budget, |
| "wall_time_sec": time.time() - item_start_time}), file=out_f, flush=True) |
| except Exception as e: |
| print(f"[ERROR] q_idx={di} qid={qid} failed: {e}", flush=True) |
| continue |
|
|
| |
|
|
| if not args.no_save_cache: |
| with open(plan_cache_file, "w") as fw: |
| json.dump(qid2plan, fw, indent=2) |
|
|
| with open(veri_reading_log_file, "w") as fw: |
| json.dump(qid2rel_sess_ids, fw, indent=2) |
|
|