Spaces:
Running
Running
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # space_app.py β HF Gradio Space for Breakthrough move prediction | |
| # | |
| # Deploy this file as app.py in your Hugging Face Space. | |
| # | |
| # Coordinate system (identical to apps/games/breakthrough_engine.py): | |
| # rows 0β7, 0-indexed | |
| # row 0 = rank 8 = TOP (Black / 'B' side) | |
| # row 7 = rank 1 = BOTTOM (White / 'W' side) | |
| # cols 0β7 β 'a'β'h' | |
| # rank formula: rank = 8 β row | |
| # | |
| # Piece β MCVS player mapping (matches engine directions): | |
| # 'B' β PLAYER1 = 1, direction = +1 (moves DOWN, row increases toward rank 1) | |
| # 'W' β PLAYER2 = 2, direction = β1 (moves UP, row decreases toward rank 8) | |
| # | |
| # move_to_uci is the exact inverse of breakthrough_engine._sq_to_coords: | |
| # (fr, fc, tr, tc) β cols[fc] + str(8βfr) + cols[tc] + str(8βtr) | |
| # | |
| # Gradio endpoint exposed: | |
| # api_name="get_move" β /gradio_api/call/get_move (Gradio 4 SSE queue) | |
| # inputs: [fen: str, player: str] β POST {"data": [fen, player]} | |
| # output: uci_move: str β SSE data: ["e4e3"] | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| import os | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| TOKEN = os.environ.get("HF_TOKEN") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 1. Load breakthrough_mcvs.py from Hub model repo into this namespace. | |
| # After exec() the following names are available globally: | |
| # Breakthrough, MCVSSearcher, HilbertOrderedZoneDatabase, | |
| # ABCModelDynamic, WeightedMatrixABC, move_to_index, β¦ | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _model_path = hf_hub_download( | |
| "test1978/breakthrough-model", | |
| "breakthrough_mcvs.py", | |
| repo_type="model", | |
| token=TOKEN, | |
| ) | |
| with open(_model_path, "r", encoding="utf-8-sig") as _fh: | |
| exec(_fh.read(), globals()) # noqa: S102 β trusted internal model file | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 2. Load the zone database from Hub dataset repo. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _db_path = hf_hub_download( | |
| "test1978/breakthrough-data", | |
| "breakthrough_zone_db.npz", | |
| repo_type="dataset", | |
| token=TOKEN, | |
| ) | |
| _db_data = np.load(_db_path, allow_pickle=True) | |
| zonedb = HilbertOrderedZoneDatabase() # noqa: F821 β defined by exec above | |
| zonedb.winning_matrices = list(_db_data.get("winning", [])) | |
| zonedb.losing_matrices = list(_db_data.get("losing", [])) | |
| zonedb.draw_matrices = list(_db_data.get("draw", [])) | |
| print( | |
| f"[INIT] Zone DB loaded: " | |
| f"W={len(zonedb.winning_matrices)} " | |
| f"L={len(zonedb.losing_matrices)} " | |
| f"D={len(zonedb.draw_matrices)}" | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 3. move_to_uci β exact inverse of breakthrough_engine._sq_to_coords | |
| # | |
| # breakthrough_engine._coords_to_sq(row, col): | |
| # rank_idx = 7 - row # row 0 β rank_idx 7 β rank "8" | |
| # return f"{_FILES[col]}{_RANKS[rank_idx]}" # _RANKS = "12345678" | |
| # βΉ rank = rank_idx + 1 = 8 - row | |
| # | |
| # Therefore: | |
| # move_to_uci((fr, fc, tr, tc)) = cols[fc] + str(8βfr) + cols[tc] + str(8βtr) | |
| # | |
| # Verification: | |
| # (4, 1, 3, 1) β 'b' + str(4) + 'b' + str(5) = 'b4b5' | |
| # bot-runner: 'b4' = _sq_to_coords('b4') = (4,1) β | |
| # 'b5' = _sq_to_coords('b5') = (3,1) β β move (4,1)β(3,1) | |
| # (7, 0, 6, 0) β 'a1a2' (White, row 7β6, rank 1β2) β | |
| # (0, 3, 1, 3) β 'd8d7' (Black, row 0β1, rank 8β7) β | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _COLS = "abcdefgh" | |
| def move_to_uci(move: tuple) -> str: | |
| fr, fc, tr, tc = move | |
| # rank = 8 - row (row 0 = rank 8, row 7 = rank 1) | |
| return _COLS[fc] + str(8 - fr) + _COLS[tc] + str(8 - tr) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 4. fen_to_board β FEN string β Breakthrough.board (numpy int32 array, 8Γ8) | |
| # | |
| # FEN example: "BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w" | |
| # First rank string = row 0 = rank 8 = top (Black side) | |
| # Last rank string = row 7 = rank 1 = bottom (White side) | |
| # | |
| # Mapping (must match Breakthrough.get_legal_moves() directions): | |
| # 'B' β 1 = PLAYER1 (direction=+1, moves DOWN, row increases) | |
| # 'W' β 2 = PLAYER2 (direction=β1, moves UP, row decreases) | |
| # | |
| # Why BβPLAYER1 and not PLAYER2? | |
| # In breakthrough_mcvs.py, PLAYER1 starts at board[0:2,:] (top rows) with | |
| # direction=+1, so it naturally represents the top-side piece ('B'). | |
| # PLAYER2 starts at board[6:8,:] (bottom rows) with direction=β1, matching 'W'. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def fen_to_board(fen: str) -> np.ndarray: | |
| ranks_str = fen.strip().split(" ")[0] | |
| rank_parts = ranks_str.split("/") | |
| if len(rank_parts) != 8: | |
| raise ValueError(f"FEN must have 8 ranks separated by '/'; got {len(rank_parts)}: {fen!r}") | |
| board = np.zeros((8, 8), dtype=np.int32) | |
| for r, rank_str in enumerate(rank_parts): | |
| c = 0 | |
| for ch in rank_str: | |
| if ch == "B": | |
| # Black piece at top rows β PLAYER1 (direction=+1 downward) | |
| board[r, c] = 1 | |
| c += 1 | |
| elif ch == "W": | |
| # White piece at bottom rows β PLAYER2 (direction=β1 upward) | |
| board[r, c] = 2 | |
| c += 1 | |
| elif ch.isdigit(): | |
| c += int(ch) # empty squares | |
| else: | |
| raise ValueError(f"Unexpected FEN char {ch!r} in rank {r}: {rank_str!r}") | |
| if c != 8: | |
| raise ValueError(f"Rank {r} has {c} columns, expected 8: {rank_str!r}") | |
| return board | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 5. get_move β main function exposed via Gradio | |
| # | |
| # Parameters | |
| # ---------- | |
| # fen : Breakthrough FEN, e.g. "BBBBBBBB/8/8/8/8/8/8/WWWWWWWW w" | |
| # player : who moves next; accepted values: | |
| # "w" or "2" β White (PLAYER2, move_count = 1 = odd) | |
| # "b" or "1" β Black (PLAYER1, move_count = 0 = even) | |
| # "" or None β infer from FEN side-to-move suffix | |
| # " w" β White (move_count=1) | |
| # " b" β Black (move_count=0) | |
| # | |
| # Returns | |
| # ------- | |
| # uci : str β a legal UCI move string ("e4e3", "b4b5", β¦). | |
| # If search produces no legal move and no legal move exists at all, | |
| # returns "0000" (a safe sentinel the bot-runner recognises as no-op). | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_move(fen: str, player: str = "") -> str: | |
| print(f"[DEBUG] get_move FEN board_text={fen!r}, player={player!r}") | |
| # ββ Build game state ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| game = Breakthrough() # noqa: F821 | |
| game.board = fen_to_board(fen) | |
| game._cached_matrix = None | |
| # Determine which player moves next (move_count parity): | |
| # even β PLAYER1 (Black / 'b'), odd β PLAYER2 (White / 'w') | |
| if player in ("w", "2"): | |
| game.move_count = 1 # White = PLAYER2 = odd move | |
| elif player in ("b", "1"): | |
| game.move_count = 0 # Black = PLAYER1 = even move | |
| else: | |
| # Infer from FEN suffix β " w" means White to move = PLAYER2 | |
| game.move_count = 1 if " w" in fen.lower() else 0 | |
| # ββ Run MCVS search βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| searcher = MCVSSearcher( # noqa: F821 | |
| policy_net=None, | |
| value_net=None, | |
| zone_db=zonedb, | |
| lambda_zone=1.0, | |
| k_zone=5, | |
| ) | |
| visits, _ = searcher.search_with_time_budget(game, 1.0) | |
| # ββ Collect legal moves for validation and fallback βββββββββββββββββββββββ | |
| legal_list = list(game.get_legal_moves()) | |
| legal_set = set(legal_list) | |
| print(f"[DEBUG] legal moves count: {len(legal_list)}") | |
| if legal_list: | |
| print(f"[DEBUG] sample legal move: {legal_list[0]}") | |
| print(f"[DEBUG] visits count: {len(visits)}") | |
| if visits: | |
| print(f"[DEBUG] sample visit key: {next(iter(visits))}") | |
| # ββ Pick the best visited move that is actually legal βββββββββββββββββββββ | |
| # search_with_time_budget should only return legal moves, but we verify | |
| # to guard against any residual coordinate-system bugs. | |
| best_move = None | |
| if visits: | |
| for candidate in sorted(visits, key=visits.get, reverse=True): | |
| if candidate in legal_set: | |
| best_move = candidate | |
| break | |
| if best_move is None: | |
| print(f"[WARNING] no visited move is legal; top visited={list(visits.keys())[:5]}") | |
| # ββ Fallback: first legal move from get_legal_moves() ββββββββββββββββββββ | |
| if best_move is None: | |
| if legal_list: | |
| best_move = legal_list[0] | |
| print(f"[WARNING] using first legal move as fallback: {best_move}") | |
| else: | |
| print("[WARNING] no legal moves in position β returning sentinel 0000") | |
| return "0000" | |
| uci = move_to_uci(best_move) | |
| print(f"[DEBUG] get_move OK: FEN={fen} -> {uci}") | |
| return uci | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 6. Gradio Interface | |
| # | |
| # api_name="get_move" β Gradio 4 SSE queue endpoint: | |
| # POST /gradio_api/call/get_move body: {"data": [fen, player]} | |
| # GET /gradio_api/call/get_move/{event_id} (SSE) | |
| # β event: complete | |
| # data: ["e4e3"] | |
| # | |
| # This matches what predict_breakthrough._try_space_api sends. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| demo = gr.Interface( | |
| fn=get_move, | |
| inputs=[ | |
| gr.Textbox( | |
| label="Breakthrough FEN", | |
| placeholder="BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w", | |
| ), | |
| gr.Textbox( | |
| label="Player to move (w / b)", | |
| placeholder="w", | |
| value="", | |
| ), | |
| ], | |
| outputs=gr.Textbox(label="UCI Move"), | |
| title="Breakthrough Move Predictor", | |
| description=( | |
| "Returns a legal UCI move for the given Breakthrough position. " | |
| "FEN format: BBBBBBBB/.../WWWWWWWW followed by w or b." | |
| ), | |
| api_name="get_move", # β /gradio_api/call/get_move | |
| ) | |
| demo.launch() |