import re
import xml.etree.ElementTree as ET

import requests
import streamlit as st
from openai import OpenAI

# =========================
# OpenAI Client
# =========================

def get_openai_client():
    api_key = st.session_state.get("OPENAI_API_KEY", "")
    if not api_key:
        raise ValueError("OpenAI API Key is not set.")
    return OpenAI(api_key=api_key)

def ask_llm(prompt, model="gpt-4.1-mini"):
    client = get_openai_client()
    res = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
    )
    return (res.choices[0].message.content or "").strip()
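
# Illustrative usage only (not executed by the app): once a key is stored in
# st.session_state["OPENAI_API_KEY"], a call such as
#     ask_llm("Reply with the single word OK")
# returns the stripped completion text, e.g. "OK".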

# =========================
# Utility
# =========================

def normalize_title(title: str) -> str:
    return " ".join((title or "").lower().strip().split())


def normalize_text(text: str) -> str:
    return " ".join((text or "").strip().split())

def deduplicate_papers(papers):
    seen = set()
    unique = []
    for p in papers:
        title = normalize_title(p.get("title", ""))
        if not title:
            continue
        authors = p.get("authors", []) or []
        first_author = authors[0].lower().strip() if authors else ""
        key = (title, first_author)
        if key not in seen:
            seen.add(key)
            unique.append(p)
    return unique
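
# The dedup key is (normalized title, lowercased first author), so records
# titled "Attention Is All You Need" and "attention is  ALL you need" with
# first author "Ashish Vaswani" collapse to one entry; the first occurrence
# in the input list is the one that is kept.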

# =========================
# arXiv Search
# =========================

def parse_arxiv_response(xml_text):
    root = ET.fromstring(xml_text)
    papers = []
    for entry in root.findall("{http://www.w3.org/2005/Atom}entry"):
        title_el = entry.find("{http://www.w3.org/2005/Atom}title")
        abstract_el = entry.find("{http://www.w3.org/2005/Atom}summary")
        date_el = entry.find("{http://www.w3.org/2005/Atom}published")
        authors = []
        for a in entry.findall("{http://www.w3.org/2005/Atom}author"):
            name_el = a.find("{http://www.w3.org/2005/Atom}name")
            if name_el is not None and name_el.text:
                authors.append(name_el.text.strip())
        title = title_el.text.strip() if title_el is not None and title_el.text else ""
        abstract = abstract_el.text.strip() if abstract_el is not None and abstract_el.text else ""
        date = date_el.text.strip() if date_el is not None and date_el.text else ""
        if title:
            papers.append(
                {
                    "title": title,
                    "authors": authors,
                    "abstract": abstract,
                    "date": date,
                    "source": "arXiv",
                    "venue": "",
                    "url": "",
                }
            )
    return papers
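
# The arXiv API responds with an Atom feed: each <entry> carries <title>,
# <summary>, <published>, and one <author><name> per author, all under the
# "http://www.w3.org/2005/Atom" namespace, hence the namespaced find() calls
# above. An abridged entry looks like:
#     <entry xmlns="http://www.w3.org/2005/Atom">
#       <title>Attention Is All You Need</title>
#       <summary>The dominant sequence transduction models ...</summary>
#       <published>2017-06-12T17:57:34Z</published>
#       <author><name>Ashish Vaswani</name></author>
#     </entry>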

def search_arxiv_once(search_query, max_results=3):
    url = "https://export.arxiv.org/api/query"
    params = {
        "search_query": search_query,
        "start": 0,
        "max_results": max_results,
        "sortBy": "relevance",
        "sortOrder": "descending",
    }
    res = requests.get(
        url,
        params=params,
        timeout=30,
        headers={"User-Agent": "paper-finder/0.1"},
    )
    res.raise_for_status()
    return parse_arxiv_response(res.text)
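
# The resulting request looks roughly like this (illustrative, shown before
# requests applies URL encoding):
#     https://export.arxiv.org/api/query?search_query=all:transformers
#         &start=0&max_results=3&sortBy=relevance&sortOrder=descending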

def search_arxiv(query, max_results=3, debug=False):
    query = normalize_text(query)
    if not query:
        return []
    terms = [t for t in re.split(r"\s+", query) if t]
    strategies = []
    # Try strategies in order, from loosest to strictest
    strategies.append(f"all:{query}")
    strategies.append(f'all:"{query}"')
    strategies.append(f'ti:"{query}"')
    if terms:
        strategies.append(" AND ".join([f"all:{t}" for t in terms]))
    seen = set()
    all_papers = []
    for s in strategies:
        try:
            if debug:
                st.write("arXiv API query:", s)
            papers = search_arxiv_once(s, max_results=max_results)
            for p in papers:
                key = normalize_title(p["title"])
                if key not in seen:
                    seen.add(key)
                    all_papers.append(p)
            if len(all_papers) >= max_results:
                return all_papers[:max_results]
        except Exception as e:
            if debug:
                st.warning(f"arXiv query failed: {s} / {e}")
    return all_papers[:max_results]
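
# For the query "graph neural networks", the strategies tried in order are:
#     all:graph neural networks                      (loose keyword match)
#     all:"graph neural networks"                    (exact phrase, any field)
#     ti:"graph neural networks"                     (exact phrase in title)
#     all:graph AND all:neural AND all:networks      (every term, any field)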

# =========================
# OpenAlex Search
# =========================

def reconstruct_abstract(inv_index):
    if not inv_index:
        return ""
    words = []
    for word, pos_list in inv_index.items():
        for pos in pos_list:
            words.append((pos, word))
    words.sort(key=lambda x: x[0])
    return " ".join(w for _, w in words)

def extract_openalex_venue(item):
    primary_location = item.get("primary_location") or {}
    source = primary_location.get("source") or {}
    venue = source.get("display_name", "") or ""
    if not venue:
        locations = item.get("locations") or []
        for loc in locations:
            src = (loc or {}).get("source") or {}
            venue = src.get("display_name", "") or ""
            if venue:
                break
    if not venue:
        host_venue = item.get("host_venue") or {}
        venue = host_venue.get("display_name", "") or ""
    return venue
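
# Venue lookup falls back through three shapes of an OpenAlex work record:
# primary_location.source, then the first locations[].source with a name,
# then the legacy host_venue field. An abridged record looks like:
#     {"primary_location": {"source": {"display_name": "NeurIPS"}}}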

def search_openalex(query, venues, max_results=3, debug=False):
    query = normalize_text(query)
    if not query or not venues:
        return []
    url = "https://api.openalex.org/works"
    params = {
        "search": query,
        "per-page": 50,
    }
    try:
        res = requests.get(
            url,
            params=params,
            timeout=30,
            headers={"User-Agent": "paper-finder/0.1"},
        )
        res.raise_for_status()
        data = res.json()
        papers = []
        for item in data.get("results", []):
            venue = extract_openalex_venue(item)
            if not any(v.lower() in venue.lower() for v in venues):
                continue
            authors = []
            for a in item.get("authorships", []):
                author = a.get("author") or {}
                name = author.get("display_name")
                if name:
                    authors.append(name)
            abstract = item.get("abstract_inverted_index")
            if isinstance(abstract, dict):
                abstract = reconstruct_abstract(abstract)
            elif not isinstance(abstract, str):
                abstract = ""
            papers.append(
                {
                    "title": item.get("title", "") or "",
                    "authors": authors,
                    "abstract": abstract,
                    "date": item.get("publication_date", "") or "",
                    "source": "OpenAlex",
                    "venue": venue,
                    "url": item.get("id", "") or "",
                }
            )
            if len(papers) >= max_results:
                break
        if debug:
            st.write("OpenAlex matched papers:", len(papers))
        return papers
    except Exception as e:
        if debug:
            st.warning(f"OpenAlex search failed: {e}")
        return []
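
# Note: venue filtering is a client-side, case-insensitive substring test over
# a single page of up to 50 results, so short names can over-match (for
# example "ACL" also matches "NAACL"); with the venue lists used below that
# overlap is harmless.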

# =========================
# LLM Utilities
# =========================

def normalize_keyword_for_search(keyword, model):
    prompt = f"""
You are an academic paper search assistant.
Convert the user input below into a short English search query that works well on arXiv and OpenAlex.
Rules:
- Output exactly one English search query
- No extra explanation
- If the input is Japanese, convert it into natural English research keywords
- If the input is English, tidy it up concisely while preserving its meaning
- Roughly 2 to 8 words is preferred
- Do not include unnecessary symbols
input: {keyword}
"""
    return normalize_text(ask_llm(prompt, model))
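
# Illustrative behavior: a Japanese input such as "グラフニューラルネットワークの応用"
# should normalize to something like "graph neural network applications".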

def paraphrase_query(keyword, model):
    prompt = f"""
Paraphrase the following research topic as an English paper-search query.
Output exactly one short English query.
No explanation is needed.
topic: {keyword}
"""
    return normalize_text(ask_llm(prompt, model))


def classify_field(keyword, model):
    prompt = f"""
Pick exactly one field from the candidates below that the following research topic mainly belongs to.
Candidates:
ML
NLP
CV
OTHER
Research topic:
{keyword}
Decision rules:
- General machine learning, optimization, representation learning, reinforcement learning, generative models, etc.: ML
- Natural language processing, dialogue, translation, summarization, LLMs, RAG, etc.: NLP
- Images, video, object detection, segmentation, 3D vision, etc.: CV
- If none of the above clearly applies: OTHER
Output exactly one label.
"""
    return ask_llm(prompt, model).strip().upper()

def summarize_paper(title, abstract, model, venue=""):
    prompt = f"""
Explain the following paper concisely in Japanese.
Title:
{title}
Venue:
{venue}
Abstract:
{abstract}
Output format:
- Summary
- What is new
- Who it is recommended for
"""
    return ask_llm(prompt, model)

def select_best_papers(papers, keyword, model, top_k=3):
    if not papers:
        return []
    if len(papers) <= top_k:
        return papers[:top_k]
    text = ""
    for i, p in enumerate(papers):
        text += f"""
Paper {i}
Title: {p.get("title", "")}
Venue: {p.get("venue", "")}
Abstract: {p.get("abstract", "")}
"""
    prompt = f"""
From the paper list below, select the {top_k} papers that are most relevant to the research topic "{keyword}" and most important.
Be sure to select distinct papers.
{text}
Output format:
0,2,5
"""
    try:
        res = ask_llm(prompt, model)
        ids = []
        for x in res.split(","):
            x = x.strip()
            if x.isdigit():
                ids.append(int(x))
        ids = list(dict.fromkeys(ids))
        results = []
        seen_titles = set()
        for i in ids:
            if 0 <= i < len(papers):
                title_key = normalize_title(papers[i].get("title", ""))
                if title_key and title_key not in seen_titles:
                    results.append(papers[i])
                    seen_titles.add(title_key)
            if len(results) >= top_k:
                break
        if results:
            return results[:top_k]
    except Exception:
        pass
    return papers[:top_k]
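
# Example of the parsing above (illustrative): if the model answers
# "2, 0, 2, 7", the digits parse to [2, 0, 2, 7], dict.fromkeys() keeps the
# first occurrence of each index ([2, 0, 7]), and out-of-range indices are
# skipped before the list is cut to top_k.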

# =========================
# Streamlit UI
# =========================

st.set_page_config(page_title="Paper Finder", layout="wide")
st.title("Paper Finder")

st.sidebar.header("Settings")
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
if openai_api_key:
    st.session_state["OPENAI_API_KEY"] = openai_api_key

model = st.sidebar.selectbox(
    "Model",
    ["gpt-4.1-mini", "gpt-4.1", "gpt-4o-mini"],
    index=0,
)
debug_mode = st.sidebar.checkbox("Debug mode", value=True)

keyword = st.text_input("Research Keyword")

if st.button("Search Papers"):
    if not st.session_state.get("OPENAI_API_KEY"):
        st.error("Please enter your OpenAI API Key.")
        st.stop()
    if not keyword.strip():
        st.warning("Please enter a Research Keyword.")
        st.stop()

    paper_list = []

    st.write("### Step 0: Query Normalization")
    try:
        normalized_keyword = normalize_keyword_for_search(keyword, model)
    except Exception as e:
        st.error(f"Search query normalization failed: {e}")
        st.stop()
    st.write("**Input keyword:**", keyword)
    st.write("**Normalized English query:**", normalized_keyword)
| st.write("### Step1 arXiv search") | |
| papers_step1 = search_arxiv(normalized_keyword, max_results=10, debug=debug_mode) | |
| paper_list.extend(papers_step1) | |
| st.write(f"found {len(papers_step1)} papers") | |
| st.write("### Step2 Query Paraphrase") | |
| try: | |
| paraphrased = paraphrase_query(normalized_keyword, model) | |
| except Exception as e: | |
| paraphrased = normalized_keyword | |
| if debug_mode: | |
| st.warning(f"Query paraphrase failed: {e}") | |
| st.write("**Paraphrased query:**", paraphrased) | |
| papers_step2 = search_arxiv(paraphrased, max_results=10, debug=debug_mode) | |
| paper_list.extend(papers_step2) | |
| st.write(f"found {len(papers_step2)} papers") | |
| st.write("### Step3 Field Classification") | |
| try: | |
| field = classify_field(keyword, model) | |
| except Exception as e: | |
| field = "OTHER" | |
| if debug_mode: | |
| st.warning(f"Field classification failed: {e}") | |
| st.write("**field:**", field) | |
| if field == "ML": | |
| venues = ["ICML", "ICLR", "NeurIPS"] | |
| elif field == "NLP": | |
| venues = ["ACL", "EMNLP", "NAACL", "AACL"] | |
| elif field == "CV": | |
| venues = ["CVPR", "ICCV", "ECCV", "SIGGRAPH"] | |
| else: | |
| venues = [] | |
| papers_step3 = [] | |
| if venues: | |
| st.write("### Step4 Top-conference Search") | |
| papers_step3 = search_openalex(normalized_keyword, venues, max_results=10, debug=debug_mode) | |
| paper_list.extend(papers_step3) | |
| st.write(f"found {len(papers_step3)} papers") | |

    paper_list = deduplicate_papers(paper_list)
    st.write("### Total candidate papers:", len(paper_list))

    if debug_mode and paper_list:
        with st.expander("Candidate Papers"):
            for i, p in enumerate(paper_list):
                st.write(
                    f"{i}. {p.get('title', '')} | venue={p.get('venue', '') or '-'} | source={p.get('source', '')}"
                )

    if not paper_list:
        st.error("No papers were found. Try a more general phrasing or a different keyword.")
        st.stop()
| st.write("### Selecting best papers") | |
| best = select_best_papers(paper_list, keyword, model, top_k=3) | |
| if not best: | |
| st.warning("ๆจ่ฆ่ซๆใฎ้ธๅฎใซๅคฑๆใใใใใๅ่ฃ่ซๆใใใฎใพใพ่กจ็คบใใพใใ") | |
| best = paper_list[:3] | |
| st.write("## Recommended Papers") | |
| for p in best: | |
| abstract = p.get("abstract", "") or "" | |
| venue = p.get("venue", "") or "-" | |
| try: | |
| summary = summarize_paper( | |
| title=p.get("title", ""), | |
| abstract=abstract, | |
| model=model, | |
| venue=venue, | |
| ) if abstract else "ใขใในใใฉใฏใใๅๅพใงใใชใใฃใใใใ่ฆ็ดใ็ๆใงใใพใใใงใใใ" | |
| except Exception as e: | |
| summary = f"่ฆ็ด็ๆใซๅคฑๆใใพใใ: {e}" | |
| st.markdown("---") | |
| st.subheader(p.get("title", "Untitled")) | |
| st.write("**Explanation:**") | |
| st.write(summary) | |
| st.write("**Authors:**", ", ".join(p.get("authors", [])) if p.get("authors") else "-") | |
| st.write("**Date:**", p.get("date", "") or "-") | |
| st.write("**Source:**", p.get("source", "") or "-") | |
| st.write("**Venue:**", venue) | |
| st.write("**Abstract:**") | |
| st.write(abstract if abstract else "ใขใในใใฉใฏใใชใ") | |