# sql-migration-env / client.py
# Eishaan's picture
# Initial Submission: SQL Schema Migration Agent for OpenEnv Hackathon
# 1b42f19
"""
SQL Migration Environment Client.

Provides the client for connecting to a SQL Migration Environment server.
Extends the base OpenEnv EnvClient for WebSocket-based persistent sessions.

Example:
    >>> from sql_migration_env import DbMigrationEnv
    >>>
    >>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
    >>> with env:
    ...     result = env.reset()
    ...     result = env.step({"sql_command": "ALTER TABLE users RENAME COLUMN first_name TO full_name", "reasoning": "test"})
    ...     print(result.observation)
"""
from typing import Any, Dict
from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult
from .models import MigrationAction, MigrationObservation, MigrationState
class DbMigrationEnv(EnvClient):
    """
    Client for talking to a SQL Migration Environment server.

    Connection management, the async/sync wrappers, and Docker/HF Space
    support all come from the ``EnvClient`` base class. This subclass only
    supplies the hooks that translate between Python objects and the JSON
    payloads exchanged with the server. The observation and state data are
    returned unchanged from the decoded server payload.

    Example:
        >>> async with DbMigrationEnv(base_url="http://localhost:7860") as env:
        ...     result = await env.reset(task_name="column-restructure")
        ...     while not result.done:
        ...         action = {"sql_command": "...", "reasoning": "...", "submit_final": False}
        ...         result = await env.step(action)
        ...     print(f"Final score: {result.observation.get('migration_progress', 0)}")

    Example with sync wrapper:
        >>> env = DbMigrationEnv(base_url="http://localhost:7860").sync()
        >>> with env:
        ...     result = env.reset()
        ...     print(result.observation)
    """

    def _step_payload(self, action: Any) -> Dict[str, Any]:
        """Turn an action into the JSON-serializable dict sent to the server.

        Args:
            action: Either a ``MigrationAction`` (serialized with
                ``model_dump()``) or a dict, which is forwarded as-is
                (the same object, not a copy).

        Returns:
            The payload dictionary for the step request.

        Raises:
            ValueError: If ``action`` is neither a ``MigrationAction`` nor
                a dict.
        """
        # The model check comes first, mirroring the original branch order.
        if isinstance(action, MigrationAction):
            return action.model_dump()
        if isinstance(action, dict):
            return action
        raise ValueError(f"Expected MigrationAction or dict, got {type(action)}")

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Build a ``StepResult`` from a decoded server response.

        Missing keys fall back to defaults: ``{}`` for ``"observation"``,
        ``None`` for ``"reward"`` and ``False`` for ``"done"``. Values that
        are present are passed through without conversion.
        """
        return StepResult(
            observation=payload.get("observation", {}),
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Return the state payload exactly as received (no conversion)."""
        return payload