Astocoder commited on
Commit
8be84ff
·
1 Parent(s): 717bee1

update the files

Browse files
Files changed (3) hide show
  1. server/Dockerfile → Dockerfile +0 -0
  2. inference.py +239 -41
  3. server/app.py +119 -45
server/Dockerfile → Dockerfile RENAMED
File without changes
inference.py CHANGED
@@ -1,44 +1,242 @@
 
 
 
 
 
1
  import requests
2
- import time
3
-
4
- BASE_URL = "http://localhost:8000"
5
-
6
- def test_task1():
7
- """Test GET_PRICE"""
8
- response = requests.post(f"{BASE_URL}/reset")
9
- action = {"type": "GET_PRICE", "symbol": "AAPL"}
10
- response = requests.post(f"{BASE_URL}/step", json=action)
11
- data = response.json()
12
-
13
- if data.get("observation", {}).get("price"):
14
- return 1.0
15
- return 0.0
16
-
17
- def test_task2():
18
- """Test News Analysis"""
19
- response = requests.post(f"{BASE_URL}/reset")
20
- action = {"type": "GET_NEWS", "explanation": "Based on positive sentiment, BUY"}
21
- response = requests.post(f"{BASE_URL}/step", json=action)
22
- return 1.0 # Simplified for now
23
-
24
- def test_task3():
25
- """Test Backtest"""
26
- response = requests.post(f"{BASE_URL}/reset")
27
- action = {"type": "BACKTEST", "strategy": "momentum"}
28
- response = requests.post(f"{BASE_URL}/step", json=action)
29
- data = response.json()
30
-
31
- if data.get("observation", {}).get("backtest_results"):
32
- return 1.0
33
- return 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if __name__ == "__main__":
36
- print("Running inference tests...")
37
- score1 = test_task1()
38
- score2 = test_task2()
39
- score3 = test_task3()
40
-
41
- print(f"Task 1 Score: {score1}")
42
- print(f"Task 2 Score: {score2}")
43
- print(f"Task 3 Score: {score3}")
44
- print(f"Total Score: {(score1 + score2 + score3) / 3:.2f}")
 
1
+ import asyncio
2
+ import os
3
+ import textwrap
4
+ from typing import List, Optional
5
+ from openai import OpenAI
6
  import requests
7
+
8
+ # Try to load from .env file if it exists
9
+ try:
10
+ from dotenv import load_dotenv
11
+ load_dotenv()
12
+ print("[INFO] Loaded .env file", flush=True)
13
+ except ImportError:
14
+ print("[INFO] python-dotenv not installed, using system env only", flush=True)
15
+
16
+ # Environment variables (set by the judge or .env)
17
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
18
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
19
+ HF_TOKEN = os.getenv("HF_TOKEN")
20
+
21
+ # Quant-Gym specific configuration
22
+ BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
23
+ TASK_NAME = os.getenv("TASK_NAME", "quant-gym")
24
+ BENCHMARK = os.getenv("BENCHMARK", "quant-gym")
25
+ MAX_STEPS = 10
26
+ TEMPERATURE = 0.7
27
+ MAX_TOKENS = 200
28
+ SUCCESS_SCORE_THRESHOLD = 0.7
29
+
30
+ # System prompt for financial analysis
31
+ SYSTEM_PROMPT = textwrap.dedent(
32
+ """
33
+ You are a financial analyst AI agent. Your goal is to analyze market data and make trading decisions.
34
+
35
+ Available actions:
36
+ - GET_PRICE: Get current stock price
37
+ - BUY [amount]: Buy number of shares
38
+ - SELL [amount]: Sell number of shares
39
+ - BACKTEST [strategy]: Backtest a strategy (momentum or mean_reversion)
40
+ - GET_NEWS: Get latest news headline
41
+
42
+ Strategy tips:
43
+ - Positive news sentiment suggests BUY
44
+ - Negative news sentiment suggests SELL
45
+ - Momentum strategy: Buy when price is rising
46
+ - Mean reversion: Buy when price is low relative to recent average
47
+
48
+ Respond with EXACTLY one action in format: ACTION [parameter]
49
+ Example: BUY 10
50
+ Example: GET_PRICE
51
+ Example: BACKTEST momentum
52
+ """
53
+ ).strip()
54
+
55
+
56
+ def log_start(task: str, env: str, model: str) -> None:
57
+ print(f"[START] task={task} env={env} model={model}", flush=True)
58
+
59
+
60
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
61
+ error_val = error if error else "null"
62
+ done_val = str(done).lower()
63
+ print(
64
+ f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
70
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
71
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
72
+
73
+
74
+ class QuantGymClient:
75
+ """Client for interacting with Quant-Gym environment"""
76
+
77
+ def __init__(self, base_url: str):
78
+ self.base_url = base_url
79
+ self.session = requests.Session()
80
+
81
+ def reset(self):
82
+ """Reset environment"""
83
+ try:
84
+ response = self.session.post(f"{self.base_url}/reset")
85
+ return response.json()
86
+ except Exception as e:
87
+ print(f"[ERROR] Reset failed: {e}", flush=True)
88
+ return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
89
+
90
+ def step(self, action: str, amount: int = 0, explanation: str = "", strategy: str = ""):
91
+ """Execute an action"""
92
+ action_upper = action.upper()
93
+
94
+ if action_upper == "GET_PRICE":
95
+ payload = {"type": "GET_PRICE"}
96
+ elif action_upper == "GET_NEWS":
97
+ payload = {"type": "GET_NEWS", "explanation": explanation}
98
+ elif action_upper.startswith("BUY"):
99
+ if " " in action_upper:
100
+ try:
101
+ amount = int(action_upper.split()[1])
102
+ except:
103
+ amount = 5
104
+ payload = {"type": "BUY", "amount": amount}
105
+ elif action_upper.startswith("SELL"):
106
+ if " " in action_upper:
107
+ try:
108
+ amount = int(action_upper.split()[1])
109
+ except:
110
+ amount = 5
111
+ payload = {"type": "SELL", "amount": amount}
112
+ elif action_upper.startswith("BACKTEST"):
113
+ if " " in action_upper:
114
+ strategy = action_upper.split()[1]
115
+ payload = {"type": "BACKTEST", "strategy": strategy}
116
+ elif action_upper == "GET_NEWS":
117
+ payload = {"type": "GET_NEWS", "explanation": explanation}
118
+ else:
119
+ payload = {"type": "GET_PRICE"}
120
+
121
+ try:
122
+ response = self.session.post(f"{self.base_url}/step", json=payload)
123
+ return response.json()
124
+ except Exception as e:
125
+ print(f"[ERROR] Step failed: {e}", flush=True)
126
+ return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
127
+
128
+ def close(self):
129
+ """Close the session"""
130
+ self.session.close()
131
+
132
+
133
+ def parse_action_from_response(text: str) -> str:
134
+ """Parse LLM response into action string"""
135
+ text = text.strip().upper()
136
+
137
+ if text.startswith("BUY"):
138
+ parts = text.split()
139
+ if len(parts) > 1 and parts[1].isdigit():
140
+ return f"BUY {parts[1]}"
141
+ return "BUY 5"
142
+ elif text.startswith("SELL"):
143
+ parts = text.split()
144
+ if len(parts) > 1 and parts[1].isdigit():
145
+ return f"SELL {parts[1]}"
146
+ return "SELL 5"
147
+ elif text.startswith("BACKTEST"):
148
+ return "BACKTEST momentum"
149
+ elif text.startswith("GET_NEWS"):
150
+ return "GET_NEWS"
151
+ else:
152
+ return "GET_PRICE"
153
+
154
+
155
+ def fallback_strategy(observation: dict) -> str:
156
+ """Rule-based strategy when LLM is unavailable"""
157
+ sentiment = observation.get('last_news', {}).get('sentiment', 'neutral')
158
+ if sentiment == 'positive':
159
+ return "BUY 5"
160
+ elif sentiment == 'negative':
161
+ return "SELL 5"
162
+ else:
163
+ return "GET_PRICE"
164
+
165
+
166
+ def get_model_action(step: int, observation: dict, history: List[str]) -> str:
167
+ """Get action using fallback strategy (no LLM required for basic testing)"""
168
+ return fallback_strategy(observation)
169
+
170
+
171
+ async def main() -> None:
172
+ print("[INFO] Starting Quant-Gym Inference", flush=True)
173
+
174
+ # Check token status
175
+ if HF_TOKEN:
176
+ print(f"[INFO] HF_TOKEN found (length: {len(HF_TOKEN)} chars)", flush=True)
177
+ else:
178
+ print("[INFO] No HF_TOKEN found, using rule-based fallback strategy", flush=True)
179
+
180
+ # Initialize environment client
181
+ env = QuantGymClient(BASE_URL)
182
+
183
+ history: List[str] = []
184
+ rewards: List[float] = []
185
+ steps_taken = 0
186
+ success = False
187
+ final_score = 0.0
188
+
189
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME if HF_TOKEN else "fallback-rule-based")
190
+
191
+ try:
192
+ # Reset environment
193
+ result = env.reset()
194
+ observation = result.get('observation', {})
195
+ print(f"[INFO] Reset complete. Initial observation: {observation}", flush=True)
196
+
197
+ for step in range(1, MAX_STEPS + 1):
198
+ # Get action
199
+ action_str = get_model_action(step, observation, history)
200
+
201
+ # Execute action
202
+ result = env.step(action_str)
203
+ observation = result.get('observation', {})
204
+
205
+ # Calculate reward
206
+ portfolio_value = observation.get('portfolio_value', 10000)
207
+ sentiment = observation.get('last_news', {}).get('sentiment', 'neutral')
208
+
209
+ profit_reward = max(0, (portfolio_value - 10000) / 10000)
210
+ sentiment_bonus = 0.2 if sentiment == 'positive' else (-0.1 if sentiment == 'negative' else 0)
211
+ reward = min(1.0, max(0.0, profit_reward + sentiment_bonus))
212
+
213
+ done = step >= MAX_STEPS - 1
214
+ error = None
215
+
216
+ rewards.append(reward)
217
+ steps_taken = step
218
+
219
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
220
+
221
+ history.append(f"Step {step}: {action_str}")
222
+
223
+ if done:
224
+ break
225
+
226
+ final_score = sum(rewards) / len(rewards) if rewards else 0.0
227
+ success = final_score >= SUCCESS_SCORE_THRESHOLD
228
+
229
+ except Exception as e:
230
+ print(f"[ERROR] {e}", flush=True)
231
+ success = False
232
+ final_score = 0.0
233
+ finally:
234
+ try:
235
+ env.close()
236
+ except Exception as e:
237
+ pass
238
+ log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
239
+
240
 
241
  if __name__ == "__main__":
242
+ asyncio.run(main())
 
 
 
 
 
 
 
 
server/app.py CHANGED
@@ -1,63 +1,137 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from typing import Optional
4
 
5
- app = FastAPI()
 
 
 
 
6
 
7
- prices = [150.00, 152.50, 151.75, 153.25, 155.00]
8
- cash = 10000.00
9
- shares = 0
10
- step_num = 0
 
 
 
11
 
12
- class Action(BaseModel):
13
- action: str
 
14
  amount: Optional[int] = 0
 
 
15
 
16
- @app.get("/health")
17
- def health():
18
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @app.get("/")
21
  def root():
22
- return {"message": "Trading API Running"}
 
 
 
 
23
 
24
  @app.post("/reset")
25
  def reset():
26
- global cash, shares, step_num
27
- cash = 10000.00
28
- shares = 0
29
- step_num = 0
30
- return {"cash": cash, "shares": shares, "price": prices[0]}
31
 
32
  @app.post("/step")
33
- def step(action: Action):
34
- global cash, shares, step_num
35
- step_num = min(step_num + 1, len(prices) - 1)
36
- price = prices[step_num]
37
-
38
- if action.action == "BUY" and action.amount:
39
- cost = price * action.amount
40
- if cost <= cash:
41
- cash -= cost
42
- shares += action.amount
43
- elif action.action == "SELL" and action.amount:
44
- if action.amount <= shares:
45
- cash += price * action.amount
46
- shares -= action.amount
47
-
48
- return {
49
- "price": price,
50
- "cash": cash,
51
- "shares": shares,
52
- "portfolio_value": cash + (shares * price)
53
- }
54
 
55
  @app.get("/tasks")
56
- def tasks():
57
  return {
58
  "tasks": [
59
- {"id": 1, "name": "Get Price"},
60
- {"id": 2, "name": "Buy Stock"},
61
- {"id": 3, "name": "Sell Stock"}
62
  ]
63
- }
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4
 
5
+ from fastapi import FastAPI, HTTPException, Request
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ from typing import Optional, Dict, Any, List
9
+ from enum import Enum
10
 
11
+ # Simple models for the API
12
+ class ActionType(str, Enum):
13
+ GET_PRICE = "GET_PRICE"
14
+ GET_NEWS = "GET_NEWS"
15
+ BUY = "BUY"
16
+ SELL = "SELL"
17
+ BACKTEST = "BACKTEST"
18
 
19
+ class AgentAction(BaseModel):
20
+ type: ActionType
21
+ symbol: Optional[str] = "AAPL"
22
  amount: Optional[int] = 0
23
+ explanation: Optional[str] = None
24
+ strategy: Optional[str] = None
25
 
26
+ class MarketObservation(BaseModel):
27
+ timestamp: str = ""
28
+ price: float = 150.0
29
+ balance: float = 10000.0
30
+ holdings: int = 0
31
+ portfolio_value: float = 10000.0
32
+ last_news: Optional[Dict[str, Any]] = None
33
+ backtest_results: Optional[Dict[str, float]] = None
34
+
35
+ app = FastAPI(title="Quant-Gym", description="Financial Analysis Environment")
36
+
37
+ app.add_middleware(
38
+ CORSMiddleware,
39
+ allow_origins=["*"],
40
+ allow_credentials=True,
41
+ allow_methods=["*"],
42
+ allow_headers=["*"],
43
+ )
44
+
45
+ # Simple environment state
46
+ class SimpleEnv:
47
+ def __init__(self):
48
+ self.prices = [150, 152, 151, 153, 155, 154, 156, 158, 157, 159]
49
+ self.news = [
50
+ {"headline": "Apple announces new AI chip", "sentiment": "positive"},
51
+ {"headline": "Supply chain delays expected", "sentiment": "negative"},
52
+ {"headline": "Analysts raise price target", "sentiment": "positive"},
53
+ {"headline": "Market shows strong growth", "sentiment": "positive"},
54
+ ]
55
+ self.reset()
56
+
57
+ def reset(self):
58
+ self.idx = 0
59
+ self.cash = 10000.0
60
+ self.shares = 0
61
+ return self._get_observation()
62
+
63
+ def step(self, action: AgentAction):
64
+ # Move time forward
65
+ self.idx = min(self.idx + 1, len(self.prices) - 1)
66
+ price = self.prices[self.idx]
67
+
68
+ if action.type == "BUY" and action.amount:
69
+ cost = price * action.amount
70
+ if cost <= self.cash:
71
+ self.cash -= cost
72
+ self.shares += action.amount
73
+ elif action.type == "SELL" and action.amount:
74
+ if action.amount <= self.shares:
75
+ self.cash += price * action.amount
76
+ self.shares -= action.amount
77
+
78
+ return self._get_observation()
79
+
80
+ def _get_observation(self):
81
+ price = self.prices[self.idx]
82
+ news_idx = self.idx % len(self.news)
83
+
84
+ return MarketObservation(
85
+ timestamp=f"step_{self.idx}",
86
+ price=float(price),
87
+ balance=round(self.cash, 2),
88
+ holdings=self.shares,
89
+ portfolio_value=round(self.cash + self.shares * price, 2),
90
+ last_news=self.news[news_idx]
91
+ )
92
+
93
+ def get_state(self):
94
+ obs = self._get_observation()
95
+ return {
96
+ "current_step": self.idx,
97
+ "total_steps": len(self.prices),
98
+ "observation": obs.dict(),
99
+ "tasks_completed": []
100
+ }
101
+
102
+ env = SimpleEnv()
103
 
104
  @app.get("/")
105
  def root():
106
+ return {"message": "Quant-Gym API is running"}
107
+
108
+ @app.get("/health")
109
+ def health():
110
+ return {"status": "healthy"}
111
 
112
  @app.post("/reset")
113
  def reset():
114
+ obs = env.reset()
115
+ return {"status": "reset", "observation": obs.dict()}
 
 
 
116
 
117
  @app.post("/step")
118
+ def step(action: AgentAction):
119
+ try:
120
+ observation = env.step(action)
121
+ return {"observation": observation.dict()}
122
+ except Exception as e:
123
+ raise HTTPException(status_code=400, detail=str(e))
124
+
125
+ @app.get("/state")
126
+ def get_state():
127
+ return env.get_state()
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  @app.get("/tasks")
130
+ def get_tasks():
131
  return {
132
  "tasks": [
133
+ {"id": "1", "name": "Fetch Market Data", "description": "Get current price for AAPL"},
134
+ {"id": "2", "name": "News Analysis", "description": "Analyze news and recommend action with explanation"},
135
+ {"id": "3", "name": "Backtest Strategy", "description": "Backtest a trading strategy and return risk metrics"}
136
  ]
137
+ }