import re
import xml.etree.ElementTree as ET

import requests
import streamlit as st
from openai import OpenAI


# =========================
# OpenAI Client
# =========================
def get_openai_client():
    api_key = st.session_state.get("OPENAI_API_KEY", "")
    if not api_key:
        raise ValueError("OpenAI API Key が未設定です。")
    return OpenAI(api_key=api_key)


def ask_llm(prompt, model="gpt-4.1-mini"):
    client = get_openai_client()
    res = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
    )
    return (res.choices[0].message.content or "").strip()


# =========================
# Utility
# =========================
def normalize_title(title: str) -> str:
    return " ".join((title or "").lower().strip().split())


def normalize_text(text: str) -> str:
    return " ".join((text or "").strip().split())


def deduplicate_papers(papers):
    seen = set()
    unique = []
    for p in papers:
        title = normalize_title(p.get("title", ""))
        if not title:
            continue
        authors = p.get("authors", []) or []
        first_author = authors[0].lower().strip() if authors else ""
        key = (title, first_author)
        if key not in seen:
            seen.add(key)
            unique.append(p)
    return unique


# =========================
# arXiv Search
# =========================
def parse_arxiv_response(xml_text):
    root = ET.fromstring(xml_text)
    papers = []
    for entry in root.findall("{http://www.w3.org/2005/Atom}entry"):
        title_el = entry.find("{http://www.w3.org/2005/Atom}title")
        abstract_el = entry.find("{http://www.w3.org/2005/Atom}summary")
        date_el = entry.find("{http://www.w3.org/2005/Atom}published")
        authors = []
        for a in entry.findall("{http://www.w3.org/2005/Atom}author"):
            name_el = a.find("{http://www.w3.org/2005/Atom}name")
            if name_el is not None and name_el.text:
                authors.append(name_el.text.strip())
        title = title_el.text.strip() if title_el is not None and title_el.text else ""
        abstract = abstract_el.text.strip() if abstract_el is not None and abstract_el.text else ""
        date = date_el.text.strip() if date_el is not None and date_el.text else ""
        if title:
            papers.append(
                {
                    "title": title,
                    "authors": authors,
                    "abstract": abstract,
                    "date": date,
                    "source": "arXiv",
                    "venue": "",
                    "url": "",
                }
            )
    return papers


def search_arxiv_once(search_query, max_results=3):
    url = "https://export.arxiv.org/api/query"
    params = {
        "search_query": search_query,
        "start": 0,
        "max_results": max_results,
        "sortBy": "relevance",
        "sortOrder": "descending",
    }
    res = requests.get(
        url,
        params=params,
        timeout=30,
        headers={"User-Agent": "paper-finder/0.1"},
    )
    res.raise_for_status()
    return parse_arxiv_response(res.text)


def search_arxiv(query, max_results=3, debug=False):
    query = normalize_text(query)
    if not query:
        return []
    terms = [t for t in re.split(r"\s+", query) if t]
    strategies = []
    # Try strategies from loosest to strictest.
    strategies.append(f'all:{query}')
    strategies.append(f'all:"{query}"')
    strategies.append(f'ti:"{query}"')
    if terms:
        strategies.append(" AND ".join([f'all:{t}' for t in terms]))
    seen = set()
    all_papers = []
    for s in strategies:
        try:
            if debug:
                st.write("arXiv API query:", s)
            papers = search_arxiv_once(s, max_results=max_results)
            for p in papers:
                key = normalize_title(p["title"])
                if key not in seen:
                    seen.add(key)
                    all_papers.append(p)
            if len(all_papers) >= max_results:
                return all_papers[:max_results]
        except Exception as e:
            if debug:
                st.warning(f"arXiv query failed: {s} / {e}")
    return all_papers[:max_results]
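# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app): what parse_arxiv_response
# extracts from an Atom feed. The XML below is a hand-written stand-in for a
# real arXiv API response, trimmed to the fields the parser actually reads.
# ---------------------------------------------------------------------------
def _demo_parse_arxiv_response():
    sample = """<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <entry>
    <title>An Example Paper Title</title>
    <summary>A short example abstract.</summary>
    <published>2024-01-01T00:00:00Z</published>
    <author><name>Jane Doe</name></author>
  </entry>
</feed>"""
    # Returns: [{"title": "An Example Paper Title", "authors": ["Jane Doe"],
    #            "abstract": "A short example abstract.",
    #            "date": "2024-01-01T00:00:00Z",
    #            "source": "arXiv", "venue": "", "url": ""}]
    return parse_arxiv_response(sample)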
# =========================
# OpenAlex Search
# =========================
def reconstruct_abstract(inv_index):
    if not inv_index:
        return ""
    words = []
    for word, pos_list in inv_index.items():
        for pos in pos_list:
            words.append((pos, word))
    words.sort(key=lambda x: x[0])
    return " ".join(w for _, w in words)


def extract_openalex_venue(item):
    primary_location = item.get("primary_location") or {}
    source = primary_location.get("source") or {}
    venue = source.get("display_name", "") or ""
    if not venue:
        locations = item.get("locations") or []
        for loc in locations:
            src = (loc or {}).get("source") or {}
            venue = src.get("display_name", "") or ""
            if venue:
                break
    if not venue:
        host_venue = item.get("host_venue") or {}
        venue = host_venue.get("display_name", "") or ""
    return venue


def search_openalex(query, venues, max_results=3, debug=False):
    query = normalize_text(query)
    if not query or not venues:
        return []
    url = "https://api.openalex.org/works"
    params = {
        "search": query,
        "per-page": 50,
    }
    try:
        res = requests.get(
            url,
            params=params,
            timeout=30,
            headers={"User-Agent": "paper-finder/0.1"},
        )
        res.raise_for_status()
        data = res.json()
        papers = []
        for item in data.get("results", []):
            venue = extract_openalex_venue(item)
            if not any(v.lower() in venue.lower() for v in venues):
                continue
            authors = []
            for a in item.get("authorships", []):
                author = a.get("author") or {}
                name = author.get("display_name")
                if name:
                    authors.append(name)
            abstract = item.get("abstract_inverted_index")
            if isinstance(abstract, dict):
                abstract = reconstruct_abstract(abstract)
            elif not isinstance(abstract, str):
                abstract = ""
            papers.append(
                {
                    "title": item.get("title", "") or "",
                    "authors": authors,
                    "abstract": abstract,
                    "date": item.get("publication_date", "") or "",
                    "source": "OpenAlex",
                    "venue": venue,
                    "url": item.get("id", "") or "",
                }
            )
            if len(papers) >= max_results:
                break
        if debug:
            st.write("OpenAlex matched papers:", len(papers))
        return papers
    except Exception as e:
        if debug:
            st.warning(f"OpenAlex search failed: {e}")
        return []
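# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app): OpenAlex delivers abstracts as
# an inverted index ({word: [positions]}) rather than plain text, and
# reconstruct_abstract puts the words back in positional order. Toy input:
# ---------------------------------------------------------------------------
def _demo_reconstruct_abstract():
    inv = {"learning": [1], "deep": [0, 3], "works": [2]}
    # Positions 0..3 resolve to: "deep learning works deep"
    return reconstruct_abstract(inv)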
# =========================
# LLM Utilities
# =========================
def normalize_keyword_for_search(keyword, model):
    prompt = f"""
あなたは学術論文検索アシスタントです。
以下のユーザー入力を、arXivやOpenAlexで検索しやすい英語の短い検索クエリに変換してください。

ルール:
- 出力は英語の検索クエリ1つだけ
- 余計な説明は不要
- 日本語入力なら自然な英語の研究キーワードへ変換
- 英語入力なら意味を保って簡潔に整形
- 2語から8語程度が望ましい
- 不要な記号は入れない

input: {keyword}
"""
    return normalize_text(ask_llm(prompt, model))


def paraphrase_query(keyword, model):
    prompt = f"""
次の研究トピックを、英語の論文検索クエリとして言い換えてください。
出力は短い英語クエリを1つだけにしてください。
説明は不要です。

topic: {keyword}
"""
    return normalize_text(ask_llm(prompt, model))


def classify_field(keyword, model):
    prompt = f"""
次の研究トピックが主に属する分野を、以下から1つだけ選んでください。

候補: ML NLP CV OTHER

研究トピック: {keyword}

判定ルール:
- 機械学習全般、最適化、表現学習、強化学習、生成モデルなどは ML
- 自然言語処理、対話、翻訳、要約、LLM、RAG などは NLP
- 画像、動画、物体検出、セグメンテーション、3D vision などは CV
- 上記に明確に当てはまらなければ OTHER

出力はラベル1つだけにしてください。
"""
    return ask_llm(prompt, model).strip().upper()


def summarize_paper(title, abstract, model, venue=""):
    prompt = f"""
次の論文を簡潔に日本語で解説してください。

Title: {title}
Venue: {venue}
Abstract: {abstract}

出力形式:
- 要約
- 何が新しいか
- どんな人におすすめか
"""
    return ask_llm(prompt, model)


def select_best_papers(papers, keyword, model, top_k=3):
    if not papers:
        return []
    if len(papers) <= top_k:
        return papers[:top_k]
    text = ""
    for i, p in enumerate(papers):
        text += f"""
Paper {i}
Title: {p.get('title', '')}
Venue: {p.get('venue', '')}
Abstract: {p.get('abstract', '')}
"""
    prompt = f"""
次の論文リストから、研究トピック「{keyword}」に最も関連があり重要度が高い論文を {top_k} 本選んでください。
必ず異なる論文を選んでください。

{text}

出力形式:
0,2,5
"""
    try:
        res = ask_llm(prompt, model)
        ids = []
        for x in res.split(","):
            x = x.strip()
            if x.isdigit():
                ids.append(int(x))
        ids = list(dict.fromkeys(ids))
        results = []
        seen_titles = set()
        for i in ids:
            if 0 <= i < len(papers):
                title_key = normalize_title(papers[i].get("title", ""))
                if title_key and title_key not in seen_titles:
                    results.append(papers[i])
                    seen_titles.add(title_key)
                if len(results) >= top_k:
                    break
        if results:
            return results[:top_k]
    except Exception:
        pass
    return papers[:top_k]
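# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app): the index-parsing contract of
# select_best_papers in isolation. The LLM is asked for a comma-separated
# list like "0,2,5"; the parser keeps digits only, deduplicates while
# preserving order, and drops out-of-range indices.
# ---------------------------------------------------------------------------
def _demo_parse_selection(reply="2, 0, 2, 7", n_papers=5):
    ids = [int(x.strip()) for x in reply.split(",") if x.strip().isdigit()]
    ids = list(dict.fromkeys(ids))  # order-preserving dedup
    # For the defaults above this returns [2, 0]: the repeated 2 and the
    # out-of-range 7 are discarded.
    return [i for i in ids if 0 <= i < n_papers]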
# =========================
# Streamlit UI
# =========================
st.set_page_config(page_title="Paper Finder", layout="wide")
st.title("📚 Paper Finder")

st.sidebar.header("Settings")
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
if openai_api_key:
    st.session_state["OPENAI_API_KEY"] = openai_api_key

model = st.sidebar.selectbox(
    "Model",
    ["gpt-4.1-mini", "gpt-4.1", "gpt-4o-mini"],
    index=0,
)
debug_mode = st.sidebar.checkbox("Debug mode", value=True)

keyword = st.text_input("Research Keyword")

if st.button("Search Papers"):
    if not st.session_state.get("OPENAI_API_KEY"):
        st.error("OpenAI API Key を入力してください。")
        st.stop()
    if not keyword.strip():
        st.warning("Research Keyword を入力してください。")
        st.stop()

    paper_list = []

    st.write("### Step0 Query Normalization")
    try:
        normalized_keyword = normalize_keyword_for_search(keyword, model)
    except Exception as e:
        st.error(f"検索クエリ正規化に失敗しました: {e}")
        st.stop()
    st.write("**Input keyword:**", keyword)
    st.write("**Normalized English query:**", normalized_keyword)

    st.write("### Step1 arXiv search")
    papers_step1 = search_arxiv(normalized_keyword, max_results=10, debug=debug_mode)
    paper_list.extend(papers_step1)
    st.write(f"found {len(papers_step1)} papers")

    st.write("### Step2 Query Paraphrase")
    try:
        paraphrased = paraphrase_query(normalized_keyword, model)
    except Exception as e:
        paraphrased = normalized_keyword
        if debug_mode:
            st.warning(f"Query paraphrase failed: {e}")
    st.write("**Paraphrased query:**", paraphrased)
    papers_step2 = search_arxiv(paraphrased, max_results=10, debug=debug_mode)
    paper_list.extend(papers_step2)
    st.write(f"found {len(papers_step2)} papers")

    st.write("### Step3 Field Classification")
    try:
        field = classify_field(keyword, model)
    except Exception as e:
        field = "OTHER"
        if debug_mode:
            st.warning(f"Field classification failed: {e}")
    st.write("**field:**", field)

    if field == "ML":
        venues = ["ICML", "ICLR", "NeurIPS"]
    elif field == "NLP":
        venues = ["ACL", "EMNLP", "NAACL", "AACL"]
    elif field == "CV":
        venues = ["CVPR", "ICCV", "ECCV", "SIGGRAPH"]
    else:
        venues = []

    papers_step3 = []
    if venues:
        st.write("### Step4 Top-conference Search")
        papers_step3 = search_openalex(normalized_keyword, venues, max_results=10, debug=debug_mode)
        paper_list.extend(papers_step3)
        st.write(f"found {len(papers_step3)} papers")

    paper_list = deduplicate_papers(paper_list)
    st.write("### Total candidate papers:", len(paper_list))

    if debug_mode and paper_list:
        with st.expander("Candidate Papers"):
            for i, p in enumerate(paper_list):
                st.write(
                    f"{i}. {p.get('title', '')} | venue={p.get('venue', '') or '-'} | source={p.get('source', '')}"
                )

    if not paper_list:
        st.error("論文が見つかりませんでした。より一般的な表現や別のキーワードで試してください。")
        st.stop()

    st.write("### Selecting best papers")
    best = select_best_papers(paper_list, keyword, model, top_k=3)
    if not best:
        st.warning("推薦論文の選定に失敗したため、候補論文をそのまま表示します。")
        best = paper_list[:3]

    st.write("## Recommended Papers")
    for p in best:
        abstract = p.get("abstract", "") or ""
        venue = p.get("venue", "") or "-"
        if abstract:
            try:
                summary = summarize_paper(
                    title=p.get("title", ""),
                    abstract=abstract,
                    model=model,
                    venue=venue,
                )
            except Exception as e:
                summary = f"要約生成に失敗しました: {e}"
        else:
            summary = "アブストラクトが取得できなかったため、要約を生成できませんでした。"
        st.markdown("---")
        st.subheader(p.get("title", "Untitled"))
        st.write("**Explanation:**")
        st.write(summary)
        st.write("**Authors:**", ", ".join(p.get("authors", [])) if p.get("authors") else "-")
        st.write("**Date:**", p.get("date", "") or "-")
        st.write("**Source:**", p.get("source", "") or "-")
        st.write("**Venue:**", venue)
        st.write("**Abstract:**")
        st.write(abstract if abstract else "アブストラクトなし")
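# ---------------------------------------------------------------------------
# Optional sketch, not wired into the flow above: Streamlit reruns the whole
# script on every widget interaction, so the same searches can refire on each
# rerun. Assuming a Streamlit version that provides st.cache_data (1.18+),
# one way to memoize the network calls is a thin cached wrapper like this
# (the name cached_search_arxiv is ours, not part of any API):
# ---------------------------------------------------------------------------
@st.cache_data(ttl=3600, show_spinner=False)
def cached_search_arxiv(query: str, max_results: int = 10):
    # Cache is keyed on the arguments; entries expire after an hour.
    return search_arxiv(query, max_results=max_results)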