File size: 11,051 Bytes
8bd78d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
FastAPI entrypoint for an ADK agent using **in-memory** session and artifact storage.

Intended for **single-instance** deployment (e.g. one Cloud Run instance with min/max
instances = 1). State is lost on restart and is not shared across replicas.

Swap points for your project:
  - Import: replace `paper_agent` / the fallback import with your real agent symbol
    (this repo exposes `root_agent` in `research_explainer.agent`).

Images in the JSON response are **data URLs** (`data:image/png;base64,...`) loaded from
the in-memory artifact store after the run, so a browser or frontend can render them
without GCS.

Session TTL: **not implemented in this module**. ``SESSION_TTL_SECONDS`` is not read
anywhere in this file, so sessions and their artifacts live until the process restarts.
If inactivity-based expiry is added, delete and recreate the session on the next request
and remove only **session-scoped** artifacts; leave ``user:`` namespaced artifacts intact
so other sessions for the same ``user_id`` are not affected.
"""

from __future__ import annotations

import base64
import logging
import os
from typing import Any, Iterable

import dotenv

dotenv.load_dotenv()

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from google.adk.artifacts import InMemoryArtifactService
from google.adk.events.event import Event
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.genai import types
from pydantic import BaseModel
import uvicorn

# -----------------------------------------------------------------------------
# SWAP: import your root ADK agent here.
# Example for this repository:
#   from research_explainer.agent import root_agent as paper_agent
# -----------------------------------------------------------------------------
try:
    from agent import paper_agent
except ImportError:  # pragma: no cover - convenience for this repo layout
    from research_explainer.agent import root_agent as paper_agent

logger = logging.getLogger(__name__)

# Upload ceiling for the optional PDF, in bytes; override with MAX_PDF_BYTES.
MAX_PDF_BYTES = int(os.getenv("MAX_PDF_BYTES", str(25 * 1024 * 1024)))

# Identity under which sessions and artifacts are stored; both env-overridable.
APP_NAME = os.getenv("ADK_APP_NAME", "research_explainer")
DEFAULT_USER_ID = os.getenv("ADK_USER_ID", "web")

# "1", "true" or "yes" (any letter case) enables the extra
# "Session already exists" log line emitted by resolve_session.
RUNNING_LOCALLY = os.getenv("RUNNING_LOCALLY", "").lower() in {"1", "true", "yes"}

# In-memory stores (see the module docstring): state is lost on restart and is
# not shared across replicas.
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()

# Single runner shared by every request; it is wired to the two stores above.
runner = Runner(
    agent=paper_agent,
    app_name=APP_NAME,
    artifact_service=artifact_service,
    session_service=session_service,
)


async def resolve_session(
    user_id: str,
    session_id: str,
    *,
    initial_state: dict[str, Any] | None = None,
) -> None:
    """
    Make sure a session with ``session_id`` exists for ``user_id``.

    The session is looked up first and only created when the lookup returns
    ``None``. ``initial_state`` is passed to the session service as the new
    session's ``state`` and is ignored when the session already exists.

    Args:
        user_id: Owner of the session.
        session_id: Identifier to look up or create.
        initial_state: Optional seed state used only on first creation.
    """
    lookup = {
        "app_name": APP_NAME,
        "user_id": user_id,
        "session_id": session_id,
    }

    found = await session_service.get_session(**lookup)
    if found is None:
        await session_service.create_session(**lookup, state=initial_state)
        logger.info(
            "New session created: app=%r user=%r session=%r",
            APP_NAME,
            user_id,
            session_id,
        )
        return

    # Re-use is only worth logging during local development.
    if RUNNING_LOCALLY:
        logger.info(
            "Session already exists: app=%r user=%r session=%r",
            APP_NAME,
            user_id,
            session_id,
        )


app = FastAPI(version="1.0.0", title="Research Explainer API")

# Open CORS policy: every origin, method and header is accepted, and
# credentials are explicitly not allowed.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


def _gather_text_for_response(events: Iterable[Event]) -> str:
    """Collect the user-visible assistant text from one run's events.

    Events are not skipped just because they also include tool calls/responses:
    the model often emits explanation text in the same turn as ``function_call``
    parts, and skipping those events would drop the explanation.

    Parts flagged as model "thought" (``part.thought`` truthy) are excluded, so
    internal reasoning is not returned to the client as part of the answer.

    Args:
        events: Events gathered from a single runner invocation, in order.

    Returns:
        The text of all final-response events joined by blank lines. When no
        final-response event carried text, every assistant text segment is
        joined the same way instead. ``""`` when there is no text at all.
    """
    final_segments: list[str] = []
    all_segments: list[str] = []

    for event in events:
        # Skip streaming partials, and user-turn events that can appear in the
        # stream; only complete assistant output is aggregated.
        if event.partial or event.author == "user":
            continue
        content = event.content
        if not content or not content.parts:
            continue

        # getattr keeps this tolerant of part objects that have no `thought`
        # attribute at all.
        segment = "".join(
            part.text
            for part in content.parts
            if part.text and not getattr(part, "thought", False)
        ).strip()
        if not segment:
            continue

        all_segments.append(segment)
        if event.is_final_response():
            final_segments.append(segment)

    # Prefer final-response text; fall back to everything the assistant said.
    return "\n\n".join(final_segments or all_segments)


async def _collect_images_as_data_urls(
    events: Iterable[Event],
    *,
    app_name: str,
    user_id: str,
    session_id: str,
) -> list[str]:
    """
    Turn the image artifacts referenced by this run's events into data URLs.

    Every ``(filename, version)`` pair found in an event's ``artifact_delta`` is
    loaded once from the artifact service. Artifacts without inline bytes, or
    whose MIME type is not ``image/*``, are skipped.

    Returns:
        ``data:<mime>;base64,<payload>`` strings in first-seen order.
    """
    data_urls: list[str] = []
    visited: set[tuple[str, int]] = set()

    for event in events:
        delta = event.actions.artifact_delta if event.actions else None
        if not delta:
            continue

        for filename, version in delta.items():
            if (filename, version) in visited:
                continue
            visited.add((filename, version))

            artifact = await artifact_service.load_artifact(
                app_name=app_name,
                user_id=user_id,
                filename=filename,
                # "user:"-prefixed filenames are loaded without a session id.
                session_id=session_id if not filename.startswith("user:") else None,
                version=version,
            )
            blob = artifact.inline_data if artifact else None
            if not blob or not blob.data:
                continue

            mime = (blob.mime_type or "application/octet-stream").lower()
            if mime.startswith("image/"):
                payload = base64.b64encode(blob.data).decode("ascii")
                data_urls.append(f"data:{mime};base64,{payload}")

    return data_urls


# Response body of POST /api/explain.
class ExplainResponse(BaseModel):
    # Assistant text assembled by _gather_text_for_response ("" when the run
    # produced no text).
    text: str
    # Image artifacts from the run, each encoded as a
    # ``data:<mime>;base64,...`` URL by _collect_images_as_data_urls.
    images: list[str]


async def _read_pdf_upload(file: UploadFile | None) -> bytes | None:
    """Validate the optional upload and return its bytes.

    Returns:
        The PDF bytes, or ``None`` when no file (or a file part without a
        filename) was sent.

    Raises:
        HTTPException: 400 when the filename does not end in ``.pdf``, the
            payload is larger than ``MAX_PDF_BYTES``, or the payload is empty.
    """
    if file is None or not getattr(file, "filename", None):
        return None

    if not str(file.filename).lower().endswith(".pdf"):
        raise HTTPException(
            status_code=400, detail="Only PDF uploads are supported (.pdf)."
        )

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling an arbitrarily large file into memory.
    pdf_bytes = await file.read(MAX_PDF_BYTES + 1)
    if len(pdf_bytes) > MAX_PDF_BYTES:
        raise HTTPException(
            status_code=400,
            detail=f"PDF exceeds maximum size of {MAX_PDF_BYTES // (1024 * 1024)} MB.",
        )
    if not pdf_bytes:
        raise HTTPException(status_code=400, detail="Uploaded PDF is empty.")
    return pdf_bytes


@app.post("/api/explain", response_model=ExplainResponse)
async def explain(
    session_id: str = Form(...),
    user_input: str = Form(""),
    file: UploadFile | None = File(default=None),
) -> ExplainResponse:
    """
    Runs one agent turn for the given ``session_id``.

    Fields are sent as **multipart/form-data**: ``session_id``, ``user_input``,
    and optional ``file`` (PDF). The PDF is attached to the user message as inline bytes
    for the model. A PDF is only accepted on the **first** turn of a session (no prior
    events); later turns must omit ``file``.

    Returns:
        The assistant text and any image artifacts produced by the run, the
        latter as data URLs.

    Raises:
        HTTPException: 400 for a blank ``session_id``, an invalid upload, a request
            with neither text nor PDF, or a PDF sent after the first turn;
            500 when session resolution or the agent run fails.
    """
    session_id = session_id.strip()
    if not session_id:
        # A whitespace-only id cannot identify a conversation; reject it up
        # front instead of letting it reach the session service.
        raise HTTPException(status_code=400, detail="session_id must not be blank.")
    user_input = (user_input or "").strip()
    user_id = DEFAULT_USER_ID

    pdf_bytes = await _read_pdf_upload(file)

    if not user_input and not pdf_bytes:
        raise HTTPException(
            status_code=400,
            detail="Provide non-empty user_input and/or a PDF file.",
        )

    try:
        await resolve_session(user_id, session_id)
    except Exception as exc:  # pragma: no cover - runtime guard
        logger.exception("Session resolution failed for session_id=%s", session_id)
        # NOTE(review): str(exc) is returned to the client as-is.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # resolve_session returns nothing, so fetch the session to inspect its history.
    existing = await session_service.get_session(
        app_name=APP_NAME,
        user_id=user_id,
        session_id=session_id,
    )
    if pdf_bytes is not None and existing and existing.events:
        raise HTTPException(
            status_code=400,
            detail="A PDF can only be attached on the first message of a conversation.",
        )

    # The PDF part (if any) precedes the text part in the user message.
    parts: list[types.Part] = []
    if pdf_bytes is not None:
        parts.append(
            types.Part.from_bytes(data=pdf_bytes, mime_type="application/pdf")
        )
    if user_input:
        parts.append(types.Part.from_text(text=user_input))

    new_message = types.Content(role="user", parts=parts)

    collected: list[Event] = []

    try:
        async for event in runner.run_async(
            user_id=user_id,
            session_id=session_id,
            new_message=new_message,
        ):
            collected.append(event)
    except HTTPException:
        raise
    except Exception as exc:  # pragma: no cover - runtime guard
        logger.exception("Runner failed for session_id=%s", session_id)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    text = _gather_text_for_response(collected)
    images = await _collect_images_as_data_urls(
        collected,
        app_name=APP_NAME,
        user_id=user_id,
        session_id=session_id,
    )

    return ExplainResponse(text=text, images=images)



# Serve the static frontend in public/ at the site root, so a SINGLE container
# serves both the web page and the /api/explain endpoint (Hugging Face Spaces).
# NOTE(review): "public" is a relative path, so it is resolved against the
# process working directory. Keep this mount after the route definitions; a
# catch-all mount at "/" registered earlier would presumably shadow /api/explain.
app.mount("/", StaticFiles(directory="public", html=True), name="static")

if __name__ == "__main__":
    logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
    port = int(os.environ.get("PORT", "8080"))
    # Pass the app object rather than the "main:app" import string: the string
    # form only works when this file is named main.py, and it makes uvicorn
    # import the module a second time (building a second runner and stores).
    uvicorn.run(app, host="0.0.0.0", port=port)