File size: 7,404 Bytes
2a92b3a
 
 
 
 
 
 
 
 
90fc756
 
 
 
 
 
2a92b3a
 
 
 
 
 
90fc756
 
 
2a92b3a
 
90fc756
 
 
 
 
2a92b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90fc756
 
 
 
 
 
 
 
 
 
 
2a92b3a
 
 
 
 
 
 
 
 
90fc756
 
 
 
2a92b3a
 
90fc756
 
2a92b3a
 
90fc756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a92b3a
90fc756
 
 
2a92b3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90fc756
2a92b3a
 
90fc756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a92b3a
 
90fc756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a92b3a
90fc756
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
"""
SQL Query Reviewer — Client
============================
Supports three connection modes:
1. from_docker_image()  — used by hackathon validator
2. Async via SQLReviewEnv(base_url=...)
3. Sync via SyncSQLReviewEnv(base_url=...)
"""

from __future__ import annotations

from typing import Any

import httpx

from sql_query_reviewer.models import (
    ResetRequest,
    SQLReviewAction,
    SQLReviewState,
    StepResult,
)


class SQLReviewEnv:
    """Async client for the SQL Query Reviewer environment.

    Use it as an async context manager (``async with SQLReviewEnv(url) as env``)
    or obtain an already-opened instance from :meth:`from_docker_image`.
    """

    def __init__(self, base_url: str, timeout: float = 30.0) -> None:
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._client: httpx.AsyncClient | None = None
        # Set by from_docker_image() when this client launched its own container,
        # so close() knows what to tear down.
        self._container_id: str | None = None

    # --- Docker image support (hackathon validator) -----------------------

    @classmethod
    async def from_docker_image(cls, image_name: str, *, port: int = 8000) -> "SQLReviewEnv":
        """
        Connect to the environment via a Docker image.

        Tries openenv-core's provider first. If that is unavailable or fails,
        runs ``docker run -d -p <port>:8000 <image_name>`` and polls
        ``GET /health`` on localhost (30 attempts, ~1s apart). If ``docker run``
        itself fails (e.g. the port is already taken by a running server), the
        health poll still proceeds, so an already-running server is reused.

        Args:
            image_name: Docker image to start.
            port: Host port on which the container's port 8000 is published
                (and which the client connects to). Defaults to 8000.

        Returns:
            An opened client. The instance is returned even if the health check
            never succeeded; subsequent calls will then fail with an httpx error.
        """
        import asyncio
        import subprocess

        try:
            # Try using openenv-core's built-in Docker provider
            from openenv.core.env_client import EnvClient

            class _Wrapper(EnvClient):
                pass

            env = await _Wrapper.from_docker_image(image_name)
            # Wrap the openenv client so our typed models work
            return _DockerEnvWrapper(env)
        except Exception:
            # Deliberate best effort: openenv-core is optional (ImportError) and
            # its provider may fail; either way fall back to the docker CLI.
            pass

        container_id: str | None = None
        try:
            result = subprocess.run(
                ["docker", "run", "-d", "-p", f"{port}:8000", image_name],
                capture_output=True,
                text=True,
                timeout=120,
            )
            # Record whatever id docker printed (regardless of exit status) so
            # close() can clean it up; an empty output means nothing to clean.
            container_id = result.stdout.strip() or None
        except (OSError, subprocess.SubprocessError):
            # docker missing or hung -- a server may still already be listening.
            pass

        # Wait for the server to report healthy, reusing one probe client.
        base_url = f"http://localhost:{port}"
        async with httpx.AsyncClient(base_url=base_url, timeout=2.0) as probe:
            for _ in range(30):
                try:
                    response = await probe.get("/health")
                    if response.status_code == 200:
                        break
                except httpx.HTTPError:
                    pass  # not accepting connections yet
                # Non-blocking wait: time.sleep() here would stall the event loop.
                await asyncio.sleep(1)

        instance = cls(base_url=base_url)
        instance._container_id = container_id
        await instance.__aenter__()
        return instance

    # --- Async context manager --------------------------------------------

    async def __aenter__(self) -> "SQLReviewEnv":
        # Reuse an already-open client (e.g. one opened by from_docker_image)
        # instead of replacing -- and leaking -- it.
        if self._client is None:
            self._client = httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout)
        return self

    async def __aexit__(self, *_: Any) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client and remove any container this client started.

        Safe to call more than once.
        """
        import asyncio

        if self._client is not None:
            await self._client.aclose()
            self._client = None
        # Clean up Docker container if we started one
        container_id = getattr(self, "_container_id", None)
        if container_id:
            # Clear first so a repeated close() does not invoke docker again.
            self._container_id = None
            # docker stop may block for its whole grace period; keep that off
            # the event-loop thread.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._remove_container, container_id)

    @staticmethod
    def _remove_container(container_id: str) -> None:
        """Best-effort stop and removal of a Docker container (blocking)."""
        import subprocess

        commands = (
            # docker stop waits up to its grace period (10s by default) before
            # killing, so the subprocess timeout must exceed that.
            ["docker", "stop", container_id],
            # Runs even if stop failed/timed out; -f kills a still-running container.
            ["docker", "rm", "-f", container_id],
        )
        for command in commands:
            try:
                subprocess.run(command, capture_output=True, timeout=30)
            except (OSError, subprocess.SubprocessError):
                pass  # cleanup is best effort

    def sync(self) -> "SyncSQLReviewEnv":
        """Return an (unopened) synchronous client for the same server."""
        return SyncSQLReviewEnv(base_url=self.base_url, timeout=self.timeout)

    # --- API methods ------------------------------------------------------

    async def reset(self, task_id: str | None = None) -> StepResult:
        """POST /reset (optionally for a specific task) and parse the StepResult."""
        client = self._require_client()
        body = ResetRequest(task_id=task_id).model_dump(exclude_none=True)
        response = await client.post("/reset", json=body)
        response.raise_for_status()
        return StepResult.model_validate(response.json())

    async def step(self, action: SQLReviewAction) -> StepResult:
        """POST /step with *action* and parse the StepResult."""
        client = self._require_client()
        response = await client.post("/step", json=action.model_dump(exclude_none=True))
        response.raise_for_status()
        return StepResult.model_validate(response.json())

    async def state(self) -> SQLReviewState:
        """GET /state and parse the SQLReviewState."""
        client = self._require_client()
        response = await client.get("/state")
        response.raise_for_status()
        return SQLReviewState.model_validate(response.json())

    def _require_client(self) -> httpx.AsyncClient:
        """Return the open client or raise RuntimeError if none is open."""
        if self._client is None:
            raise RuntimeError("Use SQLReviewEnv as an async context manager or call from_docker_image().")
        return self._client


class _DockerEnvWrapper(SQLReviewEnv):
    """Wraps an openenv-core EnvClient to present our typed interface.

    Every API call is delegated to the wrapped client and its result is
    re-validated into this package's pydantic models; no httpx client of our
    own is ever opened.
    """

    def __init__(self, inner: Any) -> None:
        # Initialise the base attributes (base_url, timeout, _client) so that
        # inherited helpers never hit AttributeError. base_url stays empty
        # because the inner client owns the connection.
        super().__init__(base_url="")
        self._inner = inner

    async def __aenter__(self) -> "_DockerEnvWrapper":
        # The inner client is already connected; opening an httpx client
        # against an empty base_url would be meaningless.
        return self

    def sync(self) -> "SyncSQLReviewEnv":
        raise NotImplementedError(
            "A synchronous client is not available for an openenv-core Docker connection."
        )

    @staticmethod
    def _as_payload(result: Any) -> Any:
        """Convert a pydantic model to a dict; pass other values through."""
        return result.model_dump() if hasattr(result, "model_dump") else result

    async def reset(self, task_id: str | None = None) -> StepResult:
        # Forward task_id when given instead of silently dropping it; keep the
        # no-argument call for the default task.
        if task_id is None:
            result = await self._inner.reset()
        else:
            result = await self._inner.reset(task_id=task_id)
        return StepResult.model_validate(self._as_payload(result))

    async def step(self, action: SQLReviewAction) -> StepResult:
        result = await self._inner.step(action)
        return StepResult.model_validate(self._as_payload(result))

    async def state(self) -> SQLReviewState:
        result = await self._inner.state()
        return SQLReviewState.model_validate(self._as_payload(result))

    async def close(self) -> None:
        try:
            await self._inner.close()
        except Exception:
            pass  # best-effort teardown of the delegated client


class SyncSQLReviewEnv:
    """Synchronous client for local dev and testing.

    Use as a context manager (``with SyncSQLReviewEnv(url) as env``); the API
    methods raise ``RuntimeError`` when no client is open.
    """

    def __init__(self, base_url: str, timeout: float = 30.0) -> None:
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._client: httpx.Client | None = None

    def __enter__(self) -> "SyncSQLReviewEnv":
        # Reuse an already-open client instead of replacing (and leaking) it.
        if self._client is None:
            self._client = httpx.Client(base_url=self.base_url, timeout=self.timeout)
        return self

    def __exit__(self, *_: Any) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client; safe to call more than once."""
        if self._client is not None:
            self._client.close()
            self._client = None

    def reset(self, task_id: str | None = None) -> StepResult:
        """POST /reset (optionally for a specific task) and parse the StepResult."""
        client = self._require_client()
        body = ResetRequest(task_id=task_id).model_dump(exclude_none=True)
        response = client.post("/reset", json=body)
        response.raise_for_status()
        return StepResult.model_validate(response.json())

    def step(self, action: SQLReviewAction) -> StepResult:
        """POST /step with *action* and parse the StepResult."""
        client = self._require_client()
        response = client.post("/step", json=action.model_dump(exclude_none=True))
        response.raise_for_status()
        return StepResult.model_validate(response.json())

    def state(self) -> SQLReviewState:
        """GET /state and parse the SQLReviewState."""
        client = self._require_client()
        response = client.get("/state")
        response.raise_for_status()
        return SQLReviewState.model_validate(response.json())

    def _require_client(self) -> httpx.Client:
        """Return the open client or raise RuntimeError if none is open."""
        if self._client is None:
            raise RuntimeError("Use SyncSQLReviewEnv as a context manager.")
        return self._client