test1978 commited on
Commit
d0d3ba8
Β·
verified Β·
1 Parent(s): 8259be7

update app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -100
app.py CHANGED
@@ -1,121 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  from huggingface_hub import hf_hub_download
4
- import os
5
 
6
  TOKEN = os.environ.get("HF_TOKEN")
7
 
8
-
9
- # MODEL repo (model type)
10
- model_repo = "test1978/breakthrough-model"
11
- model_path = hf_hub_download(
12
- model_repo,
 
 
 
13
  "breakthrough_mcvs.py",
14
  repo_type="model",
15
  token=TOKEN,
16
  )
17
- with open(model_path, "r", encoding="utf-8-sig") as f:
18
- exec(f.read())
19
 
20
-
21
- # DB dataset (dataset type!)
22
- db_repo = "test1978/breakthrough-data"
23
- db_path = hf_hub_download(
24
- db_repo,
25
  "breakthrough_zone_db.npz",
26
  repo_type="dataset",
27
  token=TOKEN,
28
  )
29
- zonedb_data = np.load(db_path, allow_pickle=True)
30
-
31
-
32
- # Init YOUR zone DB
33
- zonedb = HilbertOrderedZoneDatabase()
34
- zonedb.winning_matrices = list(zonedb_data.get("winning", []))
35
- zonedb.losing_matrices = list(zonedb_data.get("losing", []))
36
- zonedb.draw_matrices = list(zonedb_data.get("draw", []))
37
 
 
 
 
 
38
 
39
- def parse_board(board_text):
40
- print(f"[DEBUG] parse_board raw input:\n{repr(board_text)}")
41
- lines = board_text.strip().splitlines()
42
- rows = [line.strip() for line in lines if line.strip()]
43
- if len(rows) != 8:
44
- raise ValueError(f"Board must have exactly 8 rows (got {len(rows)}), raw lines:\n{lines}")
45
 
46
- board = np.zeros((8, 8), dtype=np.int32)
47
- mapping = {".": 0, "1": 1, "2": 2}
48
- for r, line in enumerate(rows):
49
- parts = line.split()
50
- if len(parts) != 8:
51
- raise ValueError(f"Each row must have 8 space-separated cells (row {r+1}: {line})")
52
- for c, cell in enumerate(parts):
53
- if cell not in mapping:
54
- raise ValueError(f"Use '.', '1', or '2' only (cell {r},{c}: {cell})")
55
- board[r, c] = mapping[cell]
56
- return board
 
 
 
 
 
 
 
 
57
 
58
 
59
- # ALIGN: coordinate system with bot_runner
60
- # Rows are 0-indexed, row 0 = TOP (black side), row 7 = BOTTOM (white side).
61
- # UCI rank = row + 1 (so row 0 -> rank 1, row 7 -> rank 8).
62
- def move_to_uci(move):
63
  fr, fc, tr, tc = move
64
- cols = "abcdefgh"
65
- return cols[fc] + str(fr + 1) + cols[tc] + str(tr + 1)
66
 
67
 
68
- def fen_to_board(fen):
69
- rows_str = fen.strip().split(" ")[0] # strip off ' w' or ' b'
70
- parts = rows_str.split("/")
71
- if len(parts) != 8:
72
- raise ValueError(f"FEN must have 8 rows, got {len(parts)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  board = np.zeros((8, 8), dtype=np.int32)
74
- for r, row in enumerate(parts):
75
  c = 0
76
- for ch in row:
77
  if ch == "B":
78
- board[r, c] = 2
 
79
  c += 1
80
  elif ch == "W":
81
- board[r, c] = 1
 
82
  c += 1
83
  elif ch.isdigit():
84
- skip = int(ch)
85
- c += skip
86
  else:
87
- raise ValueError(f"Invalid FEN char {ch} in row {r} ({row})")
88
  if c != 8:
89
- raise ValueError(f"Row {r} has {c} columns, should be 8")
90
  return board
91
 
92
- # ALIGN: only return legal moves in get_move
93
- def get_move(board_text, player=None):
94
- print(f"[DEBUG] get_move FEN board_text={repr(board_text)}, player={player}")
95
- game = Breakthrough()
96
- game.board = fen_to_board(board_text)
97
- if player == "1":
98
- game.move_count = 0
99
- elif player == "2":
100
- game.move_count = 1
101
- else:
102
- # Default: infer from FEN suffix or assume white
103
- game.move_count = 0 if " w" in board_text.lower() else 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  game._cached_matrix = None
105
- searcher = MCVSSearcher(None, None, zonedb, lambda_zone=1.0, k_zone=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  visits, _ = searcher.search_with_time_budget(game, 1.0)
107
 
108
- # Debug: show how many legal moves and visits
109
- legal_moves = list(game.get_legal_moves())
110
- legal_set = set(legal_moves)
111
- print(f"[DEBUG] legal moves count: {len(legal_moves)}")
112
- if legal_moves:
113
- print(f"[DEBUG] sample legal move: {legal_moves[0]}")
114
  print(f"[DEBUG] visits count: {len(visits)}")
115
  if visits:
116
  print(f"[DEBUG] sample visit key: {next(iter(visits))}")
117
 
118
- # Pick the best visited move that is also legal
 
 
119
  best_move = None
120
  if visits:
121
  for candidate in sorted(visits, key=visits.get, reverse=True):
@@ -123,39 +205,53 @@ def get_move(board_text, player=None):
123
  best_move = candidate
124
  break
125
  if best_move is None:
126
- print(f"[WARNING] no visited move is legal; visited={list(visits.keys())[:5]}")
127
 
128
- # Fall back to first legal move if needed
129
  if best_move is None:
130
- if legal_moves:
131
- best_move = legal_moves[0]
132
  print(f"[WARNING] using first legal move as fallback: {best_move}")
133
  else:
134
- # No legal moves at all β€” return a safe dummy
135
- print("[WARNING] no legal moves in position; returning dummy move")
136
- return move_to_uci((3, 3, 3, 3))
 
 
 
137
 
138
- move = move_to_uci(best_move)
139
- print(f"[DEBUG] get_move OK: FEN={board_text} -> {move}")
140
- return move
141
 
 
 
 
 
 
 
 
 
 
 
 
142
  demo = gr.Interface(
143
  fn=get_move,
144
- inputs=gr.Textbox(
145
- label="Breakthrough Board (1/2/.)",
146
- value=(
147
- "1 1 1 1 1 1 1 1\n"
148
- "1 1 1 1 1 1 1 1\n"
149
- ". . . . . . . .\n"
150
- ". . . . . . . .\n"
151
- ". . . . . . . .\n"
152
- ". . . . . . . .\n"
153
- "2 2 2 2 2 2 2 2\n"
154
- "2 2 2 2 2 2 2 2"
155
  ),
156
- ),
 
 
 
 
 
157
  outputs=gr.Textbox(label="UCI Move"),
158
- title="🎯 BreakthroughMCVS Secure (Model + Dataset)",
 
 
 
 
 
159
  )
160
 
161
  demo.launch()
 
1
+ # ──────────────────────────────────────────────────────────────────────────────
2
+ # space_app.py β€” HF Gradio Space for Breakthrough move prediction
3
+ #
4
+ # Deploy this file as app.py in your Hugging Face Space.
5
+ #
6
+ # Coordinate system (identical to apps/games/breakthrough_engine.py):
7
+ # rows 0–7, 0-indexed
8
+ # row 0 = rank 8 = TOP (Black / 'B' side)
9
+ # row 7 = rank 1 = BOTTOM (White / 'W' side)
10
+ # cols 0–7 β†’ 'a'–'h'
11
+ # rank formula: rank = 8 βˆ’ row
12
+ #
13
+ # Piece β†’ MCVS player mapping (matches engine directions):
14
+ # 'B' β†’ PLAYER1 = 1, direction = +1 (moves DOWN, row increases toward rank 1)
15
+ # 'W' β†’ PLAYER2 = 2, direction = βˆ’1 (moves UP, row decreases toward rank 8)
16
+ #
17
+ # move_to_uci is the exact inverse of breakthrough_engine._sq_to_coords:
18
+ # (fr, fc, tr, tc) β†’ cols[fc] + str(8βˆ’fr) + cols[tc] + str(8βˆ’tr)
19
+ #
20
+ # Gradio endpoint exposed:
21
+ # api_name="get_move" β†’ /gradio_api/call/get_move (Gradio 4 SSE queue)
22
+ # inputs: [fen: str, player: str] β†’ POST {"data": [fen, player]}
23
+ # output: uci_move: str ← SSE data: ["e4e3"]
24
+ # ──────────────────────────────────────────────────────────────────────────────
25
+ import os
26
+
27
  import gradio as gr
28
  import numpy as np
29
  from huggingface_hub import hf_hub_download
 
30
 
31
  TOKEN = os.environ.get("HF_TOKEN")
32
 
33
+ # ─────────────────────────────────────────────────────────────────────────────
34
+ # 1. Load breakthrough_mcvs.py from Hub model repo into this namespace.
35
+ # After exec() the following names are available globally:
36
+ # Breakthrough, MCVSSearcher, HilbertOrderedZoneDatabase,
37
+ # ABCModelDynamic, WeightedMatrixABC, move_to_index, …
38
+ # ─────────────────────────────────────────────────────────────────────────────
39
+ _model_path = hf_hub_download(
40
+ "test1978/breakthrough-model",
41
  "breakthrough_mcvs.py",
42
  repo_type="model",
43
  token=TOKEN,
44
  )
45
+ with open(_model_path, "r", encoding="utf-8-sig") as _fh:
46
+ exec(_fh.read(), globals()) # noqa: S102 β€” trusted internal model file
47
 
48
+ # ─────────────────────────────────────────────────────────────────────────────
49
+ # 2. Load the zone database from Hub dataset repo.
50
+ # ─────────────────────────────────────────────────────────────────────────────
51
+ _db_path = hf_hub_download(
52
+ "test1978/breakthrough-data",
53
  "breakthrough_zone_db.npz",
54
  repo_type="dataset",
55
  token=TOKEN,
56
  )
57
+ _db_data = np.load(_db_path, allow_pickle=True)
 
 
 
 
 
 
 
58
 
59
+ zonedb = HilbertOrderedZoneDatabase() # noqa: F821 β€” defined by exec above
60
+ zonedb.winning_matrices = list(_db_data.get("winning", []))
61
+ zonedb.losing_matrices = list(_db_data.get("losing", []))
62
+ zonedb.draw_matrices = list(_db_data.get("draw", []))
63
 
64
+ print(
65
+ f"[INIT] Zone DB loaded: "
66
+ f"W={len(zonedb.winning_matrices)} "
67
+ f"L={len(zonedb.losing_matrices)} "
68
+ f"D={len(zonedb.draw_matrices)}"
69
+ )
70
 
71
+ # ─────────────────────────────────────────────────────────────────────────────
72
+ # 3. move_to_uci β€” exact inverse of breakthrough_engine._sq_to_coords
73
+ #
74
+ # breakthrough_engine._coords_to_sq(row, col):
75
+ # rank_idx = 7 - row # row 0 β†’ rank_idx 7 β†’ rank "8"
76
+ # return f"{_FILES[col]}{_RANKS[rank_idx]}" # _RANKS = "12345678"
77
+ # ⟹ rank = rank_idx + 1 = 8 - row
78
+ #
79
+ # Therefore:
80
+ # move_to_uci((fr, fc, tr, tc)) = cols[fc] + str(8βˆ’fr) + cols[tc] + str(8βˆ’tr)
81
+ #
82
+ # Verification:
83
+ # (4, 1, 3, 1) β†’ 'b' + str(4) + 'b' + str(5) = 'b4b5'
84
+ # bot-runner: 'b4' = _sq_to_coords('b4') = (4,1) βœ“
85
+ # 'b5' = _sq_to_coords('b5') = (3,1) βœ“ β†’ move (4,1)β†’(3,1)
86
+ # (7, 0, 6, 0) β†’ 'a1a2' (White, row 7β†’6, rank 1β†’2) βœ“
87
+ # (0, 3, 1, 3) β†’ 'd8d7' (Black, row 0β†’1, rank 8β†’7) βœ“
88
+ # ─────────────────────────────────────────────────────────────────────────────
89
+ _COLS = "abcdefgh"
90
 
91
 
92
+ def move_to_uci(move: tuple) -> str:
 
 
 
93
  fr, fc, tr, tc = move
94
+ # rank = 8 - row (row 0 = rank 8, row 7 = rank 1)
95
+ return _COLS[fc] + str(8 - fr) + _COLS[tc] + str(8 - tr)
96
 
97
 
98
+ # ─────────────────────────────────────────────────────────────────────────────
99
+ # 4. fen_to_board β€” FEN string β†’ Breakthrough.board (numpy int32 array, 8Γ—8)
100
+ #
101
+ # FEN example: "BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w"
102
+ # First rank string = row 0 = rank 8 = top (Black side)
103
+ # Last rank string = row 7 = rank 1 = bottom (White side)
104
+ #
105
+ # Mapping (must match Breakthrough.get_legal_moves() directions):
106
+ # 'B' β†’ 1 = PLAYER1 (direction=+1, moves DOWN, row increases)
107
+ # 'W' β†’ 2 = PLAYER2 (direction=βˆ’1, moves UP, row decreases)
108
+ #
109
+ # Why B→PLAYER1 and not PLAYER2?
110
+ # In breakthrough_mcvs.py, PLAYER1 starts at board[0:2,:] (top rows) with
111
+ # direction=+1, so it naturally represents the top-side piece ('B').
112
+ # PLAYER2 starts at board[6:8,:] (bottom rows) with direction=βˆ’1, matching 'W'.
113
+ # ─────────────────────────────────────────────────────────────────────────────
114
+ def fen_to_board(fen: str) -> np.ndarray:
115
+ ranks_str = fen.strip().split(" ")[0]
116
+ rank_parts = ranks_str.split("/")
117
+ if len(rank_parts) != 8:
118
+ raise ValueError(f"FEN must have 8 ranks separated by '/'; got {len(rank_parts)}: {fen!r}")
119
+
120
  board = np.zeros((8, 8), dtype=np.int32)
121
+ for r, rank_str in enumerate(rank_parts):
122
  c = 0
123
+ for ch in rank_str:
124
  if ch == "B":
125
+ # Black piece at top rows β†’ PLAYER1 (direction=+1 downward)
126
+ board[r, c] = 1
127
  c += 1
128
  elif ch == "W":
129
+ # White piece at bottom rows β†’ PLAYER2 (direction=βˆ’1 upward)
130
+ board[r, c] = 2
131
  c += 1
132
  elif ch.isdigit():
133
+ c += int(ch) # empty squares
 
134
  else:
135
+ raise ValueError(f"Unexpected FEN char {ch!r} in rank {r}: {rank_str!r}")
136
  if c != 8:
137
+ raise ValueError(f"Rank {r} has {c} columns, expected 8: {rank_str!r}")
138
  return board
139
 
140
+
141
+ # ─────────────────────────────────────────────────────────────────────────────
142
+ # 5. get_move β€” main function exposed via Gradio
143
+ #
144
+ # Parameters
145
+ # ----------
146
+ # fen : Breakthrough FEN, e.g. "BBBBBBBB/8/8/8/8/8/8/WWWWWWWW w"
147
+ # player : who moves next; accepted values:
148
+ # "w" or "2" β†’ White (PLAYER2, move_count = 1 = odd)
149
+ # "b" or "1" β†’ Black (PLAYER1, move_count = 0 = even)
150
+ # "" or None β†’ infer from FEN side-to-move suffix
151
+ # " w" β†’ White (move_count=1)
152
+ # " b" β†’ Black (move_count=0)
153
+ #
154
+ # Returns
155
+ # -------
156
+ # uci : str β€” a legal UCI move string ("e4e3", "b4b5", …).
157
+ # If search produces no legal move and no legal move exists at all,
158
+ # returns "0000" (a safe sentinel the bot-runner recognises as no-op).
159
+ # ─────────────────────────────────────────────────────────────────────────────
160
+ def get_move(fen: str, player: str = "") -> str:
161
+ print(f"[DEBUG] get_move FEN board_text={fen!r}, player={player!r}")
162
+
163
+ # ── Build game state ──────────────────────────────────────────────────────
164
+ game = Breakthrough() # noqa: F821
165
+ game.board = fen_to_board(fen)
166
  game._cached_matrix = None
167
+
168
+ # Determine which player moves next (move_count parity):
169
+ # even β†’ PLAYER1 (Black / 'b'), odd β†’ PLAYER2 (White / 'w')
170
+ if player in ("w", "2"):
171
+ game.move_count = 1 # White = PLAYER2 = odd move
172
+ elif player in ("b", "1"):
173
+ game.move_count = 0 # Black = PLAYER1 = even move
174
+ else:
175
+ # Infer from FEN suffix β€” " w" means White to move = PLAYER2
176
+ game.move_count = 1 if " w" in fen.lower() else 0
177
+
178
+ # ── Run MCVS search ───────────────────────────────────────────────────────
179
+ searcher = MCVSSearcher( # noqa: F821
180
+ policy_net=None,
181
+ value_net=None,
182
+ zone_db=zonedb,
183
+ lambda_zone=1.0,
184
+ k_zone=5,
185
+ )
186
  visits, _ = searcher.search_with_time_budget(game, 1.0)
187
 
188
+ # ── Collect legal moves for validation and fallback ───────────────────────
189
+ legal_list = list(game.get_legal_moves())
190
+ legal_set = set(legal_list)
191
+ print(f"[DEBUG] legal moves count: {len(legal_list)}")
192
+ if legal_list:
193
+ print(f"[DEBUG] sample legal move: {legal_list[0]}")
194
  print(f"[DEBUG] visits count: {len(visits)}")
195
  if visits:
196
  print(f"[DEBUG] sample visit key: {next(iter(visits))}")
197
 
198
+ # ── Pick the best visited move that is actually legal ─────────────────────
199
+ # search_with_time_budget should only return legal moves, but we verify
200
+ # to guard against any residual coordinate-system bugs.
201
  best_move = None
202
  if visits:
203
  for candidate in sorted(visits, key=visits.get, reverse=True):
 
205
  best_move = candidate
206
  break
207
  if best_move is None:
208
+ print(f"[WARNING] no visited move is legal; top visited={list(visits.keys())[:5]}")
209
 
210
+ # ── Fallback: first legal move from get_legal_moves() ────────────────────
211
  if best_move is None:
212
+ if legal_list:
213
+ best_move = legal_list[0]
214
  print(f"[WARNING] using first legal move as fallback: {best_move}")
215
  else:
216
+ print("[WARNING] no legal moves in position β€” returning sentinel 0000")
217
+ return "0000"
218
+
219
+ uci = move_to_uci(best_move)
220
+ print(f"[DEBUG] get_move OK: FEN={fen} -> {uci}")
221
+ return uci
222
 
 
 
 
223
 
224
+ # ─────────────────────────────────────────────────────────────────────────────
225
+ # 6. Gradio Interface
226
+ #
227
+ # api_name="get_move" β†’ Gradio 4 SSE queue endpoint:
228
+ # POST /gradio_api/call/get_move body: {"data": [fen, player]}
229
+ # GET /gradio_api/call/get_move/{event_id} (SSE)
230
+ # β†’ event: complete
231
+ # data: ["e4e3"]
232
+ #
233
+ # This matches what predict_breakthrough._try_space_api sends.
234
+ # ─────────────────────────────────────────────────────────────────────────────
235
  demo = gr.Interface(
236
  fn=get_move,
237
+ inputs=[
238
+ gr.Textbox(
239
+ label="Breakthrough FEN",
240
+ placeholder="BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w",
 
 
 
 
 
 
 
241
  ),
242
+ gr.Textbox(
243
+ label="Player to move (w / b)",
244
+ placeholder="w",
245
+ value="",
246
+ ),
247
+ ],
248
  outputs=gr.Textbox(label="UCI Move"),
249
+ title="Breakthrough Move Predictor",
250
+ description=(
251
+ "Returns a legal UCI move for the given Breakthrough position. "
252
+ "FEN format: BBBBBBBB/.../WWWWWWWW followed by w or b."
253
+ ),
254
+ api_name="get_move", # β†’ /gradio_api/call/get_move
255
  )
256
 
257
  demo.launch()