from __future__ import annotations
import base64
import os
from pathlib import Path
from typing import List
import pandas as pd
import networkx as nx
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
import streamlit.components.v1 as components
# Hugging Face dataset repository id, read from the environment.
# When this is empty, `_read` loads parquet files from a local directory.
HF_REPO_ID = os.getenv("HF_REPO_ID", "")
def csv_download_link(data: bytes, filename: str, label: str) -> None:
    """Render a download link that embeds *data* as a base64 ``data:`` URI.

    Used instead of ``st.download_button``: the payload is carried inside the
    page's HTML, so clicking the link needs no further server connection.

    Args:
        data: Raw bytes of the file to offer (CSV content).
        filename: Name suggested to the browser via the ``download`` attribute.
        label: Visible link text; inserted as-is, so it may contain markup.
    """
    import html  # local import: only needed to quote the attribute value

    b64 = base64.b64encode(data).decode()
    # Quote the filename so characters such as '"' cannot break the attribute.
    safe_name = html.escape(filename, quote=True)
    st.markdown(
        f'<a href="data:text/csv;base64,{b64}" download="{safe_name}">'
        f'{label}</a>',
        unsafe_allow_html=True,
    )
# Optional Hugging Face access token, read from the environment.
# `_hf_download` passes None instead when this is empty.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# Configure the Streamlit page: browser title, icon, and wide layout.
st.set_page_config(
    page_title="CitationHub",
    page_icon="📚",
    layout="wide",
)
# Citation-intent labels, in display order.
ALLOWED_INTENTS = [
    "background",
    "uses",
    "similarities",
    "motivation",
    "differences",
    "future_work",
    "extends",
]

# Hex colour assigned to each citation intent.
INTENT_COLORS = {
    "background": "#94a3b8",
    "uses": "#22c55e",
    "similarities": "#3b82f6",
    "motivation": "#f59e0b",
    "differences": "#ef4444",
    "future_work": "#8b5cf6",
    "extends": "#06b6d4",
}

# Hex colour per node type (light palette, apart from the dark seed paper).
NODE_COLORS = {
    "seed_paper": "#111827",
    "citing_paper": "#dbeafe",
    "citation_event": "#fde68a",
    "journal": "#ede9fe",
    "author": "#fee2e2",
    "affiliation": "#fae8ff",
    "city": "#cffafe",
    "country": "#ffedd5",
    "field": "#e0e7ff",
    "intent": "#dcfce7",
}

# Hex colour per node type (saturated palette) used by the Plotly figures;
# the key order is also the order in which node traces are added.
NODE_TYPE_COLORS = {
    "seed_paper": "#111827",
    "citing_paper": "#3b82f6",
    "citation_event": "#f59e0b",
    "journal": "#8b5cf6",
    "author": "#ef4444",
    "affiliation": "#ec4899",
    "city": "#06b6d4",
    "country": "#f97316",
    "field": "#6366f1",
    "intent": "#22c55e",
}

# Local data directory: CITATIONHUB_DATA_DIR if set, otherwise a fixed path.
# NOTE(review): the fallback is a machine-specific Windows path — confirm it
# is intended outside the original author's environment.
_data_dir_fallback = r"C:\Users\user\OneDrive\바탕 화면\Citehub_huggingface\data"
DEFAULT_DATA_DIR = Path(os.environ.get("CITATIONHUB_DATA_DIR", _data_dir_fallback))
def fmt_num(x) -> str:
    """Format *x* as an integer with thousands separators.

    Args:
        x: Any value accepted by ``int()`` (numbers, numeric strings, ...).

    Returns:
        The comma-grouped integer text, or ``"-"`` when *x* cannot be
        converted (e.g. ``None``, NaN, infinity, non-numeric text).
    """
    try:
        return f"{int(x):,}"
    # Only conversion failures mean "no value"; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit.
    except (TypeError, ValueError, OverflowError):
        return "-"
def _hf_download(filename: str) -> str:
    """Download ``data/<filename>`` from the configured Hugging Face dataset.

    The repository id comes from ``HF_REPO_ID``; ``HF_TOKEN`` is used as the
    token when non-empty, otherwise ``None`` is passed.

    Returns:
        The value returned by ``hf_hub_download``.
    """
    # Imported at call time, so the package is only required on this path.
    from huggingface_hub import hf_hub_download

    remote_path = f"data/{filename}"
    auth_token = HF_TOKEN if HF_TOKEN else None
    return hf_hub_download(
        repo_id=HF_REPO_ID,
        repo_type="dataset",
        filename=remote_path,
        token=auth_token,
    )
def _read(filename: str, data_dir: Path | None = None) -> pd.DataFrame:
    """Load one parquet file as a DataFrame.

    Args:
        filename: File name of the parquet file.
        data_dir: Local directory to read from when ``HF_REPO_ID`` is empty.
            Defaults to ``DEFAULT_DATA_DIR`` when omitted.

    Returns:
        The parquet contents. When ``HF_REPO_ID`` is set the file is fetched
        through ``_hf_download`` and *data_dir* is ignored.
    """
    if HF_REPO_ID:
        return pd.read_parquet(_hf_download(filename))
    # Previously ``None / filename`` raised TypeError when no directory was
    # passed; fall back to the module-level default instead.
    base_dir = DEFAULT_DATA_DIR if data_dir is None else Path(data_dir)
    return pd.read_parquet(base_dir / filename)
def plotly_network_fig(
    nodes_df: pd.DataFrame,
    edges_df: pd.DataFrame,
    title: str = "",
    height: int = 750,
    seed_node_ids: list | None = None,
) -> go.Figure:
    """Build an SVG-based Plotly network graph that stays sharp when zoomed.

    Args:
        nodes_df: One row per node. Reads ``node_id`` and ``node_type``, and
            optionally ``label``, ``doi``, ``publication_name``, ``group``.
        edges_df: One row per edge. Reads ``source`` and ``target`` (and
            ``edge_type`` if present). Edges whose endpoints are not both in
            *nodes_df* are dropped.
        title: Figure title; an empty string shrinks the top margin.
        height: Figure height in pixels.
        seed_node_ids: Currently unused; kept so existing callers keep working.

    Returns:
        A figure with one line trace for the edges plus one marker trace per
        node type listed in ``NODE_TYPE_COLORS``. An empty figure is returned
        when there are no nodes.
    """

    def _hover_field(row, key: str) -> str:
        """Return ``row[key]`` as hover text, or "-" if missing/empty/NaN."""
        value = row.get(key, "")
        # NaN is truthy and bool(pd.NA) raises, so ``value or "-"`` is not
        # enough to detect missing values coming from parquet data.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return "-"
        return str(value) if value else "-"

    G = nx.Graph()
    for _, row in nodes_df.iterrows():
        G.add_node(str(row["node_id"]))
    for _, row in edges_df.iterrows():
        s, t = str(row["source"]), str(row["target"])
        # Skip edges that point at nodes missing from nodes_df.
        if G.has_node(s) and G.has_node(t):
            G.add_edge(s, t, edge_type=row.get("edge_type", ""))

    if G.number_of_nodes() == 0:
        return go.Figure()

    # Fixed seed keeps the layout stable between reruns.
    k = max(1.5, 3.0 / (G.number_of_nodes() ** 0.4))
    pos = nx.spring_layout(G, seed=42, k=k, iterations=60)

    # All edges go into one trace; a None entry breaks the line between
    # consecutive segments.
    ex, ey = [], []
    for src, tgt in G.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[tgt]
        ex += [x0, x1, None]
        ey += [y0, y1, None]

    traces: list[go.BaseTraceType] = [
        go.Scatter(
            x=ex, y=ey, mode="lines",
            line=dict(width=0.8, color="#cbd5e1"),
            hoverinfo="none", showlegend=False,
        )
    ]

    for ntype, color in NODE_TYPE_COLORS.items():
        subset = nodes_df[nodes_df["node_type"] == ntype]
        if subset.empty:
            continue
        is_seed = ntype == "seed_paper"
        xs, ys, hovers, texts = [], [], [], []
        for _, row in subset.iterrows():
            nid = str(row["node_id"])
            if nid not in pos:
                continue
            x, y = pos[nid]
            xs.append(x)
            ys.append(y)
            label = str(row.get("label", ""))[:50]
            # Only seed papers carry a permanently visible label.
            texts.append(label if is_seed else "")
            # Plotly renders line breaks in hover text only via <br>.
            hovers.append(
                f"{label}<br>"
                f"Type: {ntype}<br>"
                f"DOI: {_hover_field(row, 'doi')}<br>"
                f"Pub: {_hover_field(row, 'publication_name')}<br>"
                f"Group: {_hover_field(row, 'group')}"
            )
        traces.append(go.Scatter(
            x=xs, y=ys,
            mode="markers+text" if is_seed else "markers",
            text=texts, textposition="top center",
            hovertext=hovers, hoverinfo="text",
            name=ntype,
            marker=dict(
                size=20 if is_seed else 10,
                color=color,
                line=dict(width=1.5 if is_seed else 0.5, color="white"),
                symbol="circle",
            ),
        ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        title=dict(text=title, font=dict(size=14)),
        showlegend=True,
        legend=dict(title="Node type", itemsizing="constant"),
        hovermode="closest",
        height=height,
        margin=dict(l=0, r=0, t=40 if title else 10, b=0),
        paper_bgcolor="white",
        plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def plotly_ontology_fig(height: int = 820) -> go.Figure:
    """Draw the CitationHub ontology schema as an SVG Plotly figure.

    Each class node shows its name above the marker and its property list
    below it; each relation is drawn as a line labelled at its midpoint.

    Args:
        height: Figure height in pixels.

    Returns:
        A figure with one line trace for the relations and one marker trace
        per ontology class.
    """
    # Property summary per node type; "\n" marks the intended line breaks.
    NODE_PROPS = {
        "seed_paper": "doi · title · journal\nauthor · affiliation\ncountry · field · citedby_count",
        "citation_event": "event_id · citing_year\nprimary_intent · context\nis_influential",
        "citing_paper": "doi · title\nyear · venue · oa_pdf",
        "intent": "background · uses\nsimilarities · motivation\ndifferences · future_work · extends",
        "journal": "journal_name",
        "author": "author_name · author_id",
        "affiliation": "affiliation_name",
        "city": "city_name",
        "country": "country_name",
        "field": "field_name",
    }
    # (graph id, displayed class name, node type used for colour/properties)
    node_defs = [
        ("seed", "Top5PctCitedPaper", "seed_paper"),
        ("event", "CitationEvent", "citation_event"),
        ("citing", "CitingPaper", "citing_paper"),
        ("intent", "Intent", "intent"),
        ("journal", "Journal", "journal"),
        ("author", "Author", "author"),
        ("affiliation", "Affiliation", "affiliation"),
        ("city", "City", "city"),
        ("country", "Country", "country"),
        ("field", "Field", "field"),
    ]
    # (source id, target id, relation label)
    edge_defs = [
        ("event", "citing", "hasCitingPaper"),
        ("event", "seed", "hasCitedPaper"),
        ("event", "intent", "hasPrimaryIntent"),
        ("seed", "journal", "publishedInJournal"),
        ("seed", "author", "hasAuthor"),
        ("seed", "affiliation", "hasAffiliation"),
        ("seed", "city", "locatedInCity"),
        ("seed", "country", "locatedInCountry"),
        ("seed", "field", "belongsToField"),
    ]

    G = nx.DiGraph()
    for nid, _, _ in node_defs:
        G.add_node(nid)
    for s, t, _ in edge_defs:
        G.add_edge(s, t)
    # Fixed seed keeps the diagram layout identical between reruns.
    pos = nx.spring_layout(G, seed=7, k=2.5, iterations=80)

    ex, ey = [], []
    ann = []
    for s, t, lbl in edge_defs:
        x0, y0 = pos[s]
        x1, y1 = pos[t]
        # None separates consecutive segments inside the single edge trace.
        ex += [x0, x1, None]
        ey += [y0, y1, None]
        # Relation name sits at the midpoint of the edge.
        ann.append(dict(
            x=(x0 + x1) / 2, y=(y0 + y1) / 2, text=f"{lbl}",
            showarrow=False, font=dict(size=9, color="#64748b"),
            bgcolor="rgba(255,255,255,0.75)",
        ))

    traces: list[go.BaseTraceType] = [
        go.Scatter(
            x=ex, y=ey, mode="lines",
            line=dict(width=1.2, color="#94a3b8"),
            hoverinfo="none", showlegend=False,
        )
    ]

    for nid, label, ntype in node_defs:
        x, y = pos[nid]
        color = NODE_TYPE_COLORS.get(ntype, "#94a3b8")
        props = NODE_PROPS.get(ntype, "")
        # Plotly text ignores "\n"; line breaks must be written as <br>.
        prop_html = props.replace("\n", "<br>")
        traces.append(go.Scatter(
            x=[x], y=[y], mode="markers+text",
            text=[f"{label}"], textposition="top center",
            hoverinfo="text",
            hovertext=f"{label}<br>Type: {ntype}<br>" + prop_html,
            name=label, showlegend=False,
            marker=dict(size=24, color=color,
                        line=dict(width=1.5, color="white")),
            textfont=dict(size=11, color="#1e293b"),
        ))
        if props:
            # Property list is anchored just below the node marker.
            ann.append(dict(
                x=x, y=y,
                text=f"{prop_html}",
                showarrow=False,
                xanchor="center",
                yanchor="top",
                yshift=-22,
                font=dict(size=8, color="#64748b"),
                bgcolor="rgba(248,250,252,0.85)",
                borderpad=2,
            ))

    fig = go.Figure(data=traces)
    fig.update_layout(
        showlegend=False, hovermode="closest", height=height,
        annotations=ann,
        margin=dict(l=10, r=10, t=20, b=10),
        paper_bgcolor="white", plot_bgcolor="#f8fafc",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
    )
    return fig
def inject_fullscreen(html: str) -> str:
extra = """