# ─────────────────────────────────────────────────────────────────────────────
# space_app.py — HF Gradio Space for Breakthrough move prediction
#
# Deploy this file as app.py in your Hugging Face Space.
#
# Coordinate system (identical to apps/games/breakthrough_engine.py):
#   rows 0–7, 0-indexed
#     row 0 = rank 8 = TOP    (Black / 'B' side)
#     row 7 = rank 1 = BOTTOM (White / 'W' side)
#   cols 0–7 → 'a'–'h'
#   rank formula: rank = 8 - row
#
# Piece → MCVS player mapping (matches engine directions):
#   'B' → PLAYER1 = 1, direction = +1 (moves DOWN, row increases toward rank 1)
#   'W' → PLAYER2 = 2, direction = -1 (moves UP,   row decreases toward rank 8)
#
# move_to_uci is the exact inverse of breakthrough_engine._sq_to_coords:
#   (fr, fc, tr, tc) → cols[fc] + str(8 - fr) + cols[tc] + str(8 - tr)
#
# Gradio endpoint exposed:
#   api_name="get_move"  →  /gradio_api/call/get_move  (Gradio 4 SSE queue)
#   inputs:  [fen: str, player: str]  →  POST {"data": [fen, player]}
#   output:  uci_move: str            →  SSE data: ["e4e3"]
# ─────────────────────────────────────────────────────────────────────────────
import os
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
# Hub access token read from the environment; None when HF_TOKEN is unset.
TOKEN = os.environ.get("HF_TOKEN")
# ─────────────────────────────────────────────────────────────────────────────
# 1. Load breakthrough_mcvs.py from the Hub model repo into this namespace.
#    After exec() the following names are available globally:
#    Breakthrough, MCVSSearcher, HilbertOrderedZoneDatabase,
#    ABCModelDynamic, WeightedMatrixABC, move_to_index, …
#
# SECURITY: this executes downloaded Python code with the Space's full
# privileges. It is only acceptable because the model repo is treated as
# trusted; consider pinning hf_hub_download(..., revision="<commit sha>") so a
# push to the repo cannot silently change the code that runs here.
#
# NOTE(review): exec() runs with this module's globals(), so the downloaded
# code sees this module's __name__ ("__main__" when launched as a script). If
# breakthrough_mcvs.py has an `if __name__ == "__main__":` section it would
# execute at startup — confirm it does not.
# ─────────────────────────────────────────────────────────────────────────────
_model_path = hf_hub_download(
    "test1978/breakthrough-model",
    "breakthrough_mcvs.py",
    repo_type="model",
    token=TOKEN,
)
with open(_model_path, "r", encoding="utf-8-sig") as _fh:
    # "utf-8-sig" drops a leading BOM if present. Compiling with the real
    # cache path (rather than exec'ing a bare string) makes tracebacks raised
    # inside the downloaded module report its file and source lines instead
    # of "<string>".
    exec(compile(_fh.read(), _model_path, "exec"), globals())  # noqa: S102 — trusted internal model file
# ─────────────────────────────────────────────────────────────────────────────
# 2. Load the zone database from the Hub dataset repo.
#
# The .npz archive is read for the keys "winning", "losing" and "draw"; a
# missing key yields an empty list. Each list element is one entry along the
# first axis of the stored array.
# ─────────────────────────────────────────────────────────────────────────────
_db_path = hf_hub_download(
    "test1978/breakthrough-data",
    "breakthrough_zone_db.npz",
    repo_type="dataset",
    token=TOKEN,
)
# np.load on an .npz returns an NpzFile that holds the archive open until it
# is closed; the `with` block releases the file handle once the arrays have
# been read into the lists below.
# allow_pickle=True lets object arrays be unpickled from the archive —
# presumably needed for this file (confirm); only safe because the dataset
# repo is trusted.
with np.load(_db_path, allow_pickle=True) as _db_data:
    zonedb = HilbertOrderedZoneDatabase()  # noqa: F821 — defined by exec above
    zonedb.winning_matrices = list(_db_data.get("winning", []))
    zonedb.losing_matrices = list(_db_data.get("losing", []))
    zonedb.draw_matrices = list(_db_data.get("draw", []))
print(
    f"[INIT] Zone DB loaded: "
    f"W={len(zonedb.winning_matrices)} "
    f"L={len(zonedb.losing_matrices)} "
    f"D={len(zonedb.draw_matrices)}"
)
# ─────────────────────────────────────────────────────────────────────────────
# 3. move_to_uci — exact inverse of breakthrough_engine._sq_to_coords
#
# breakthrough_engine._coords_to_sq(row, col):
#     rank_idx = 7 - row                          # row 0 → rank_idx 7 → rank "8"
#     return f"{_FILES[col]}{_RANKS[rank_idx]}"   # _RANKS = "12345678"
#     ⇒ rank = rank_idx + 1 = 8 - row
# ─────────────────────────────────────────────────────────────────────────────
_COLS = "abcdefgh"


def move_to_uci(move: tuple) -> str:
    """Render an engine move tuple as a 4-character UCI-style string.

    ``move`` is ``(from_row, from_col, to_row, to_col)`` in board coordinates
    (row 0 = rank 8 at the top, row 7 = rank 1 at the bottom). Each square is
    written as the file letter ``_COLS[col]`` followed by the rank ``8 - row``.

    Examples (checked against the bot-runner's _sq_to_coords):
        (4, 1, 3, 1) -> "b4b5"   b4 = (4, 1), b5 = (3, 1)
        (7, 0, 6, 0) -> "a1a2"   White, row 7→6, rank 1→2
        (0, 3, 1, 3) -> "d8d7"   Black, row 0→1, rank 8→7
    """
    src_row, src_col, dst_row, dst_col = move

    def square(row, col):
        # rank = 8 - row  (row 0 = rank 8, row 7 = rank 1)
        return f"{_COLS[col]}{8 - row}"

    return square(src_row, src_col) + square(dst_row, dst_col)
# ─────────────────────────────────────────────────────────────────────────────
# 4. fen_to_board — FEN string → Breakthrough.board (numpy int32 array, 8×8)
#
# FEN example: "BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w"
#   First rank string = row 0 = rank 8 = top    (Black side)
#   Last  rank string = row 7 = rank 1 = bottom (White side)
#
# Mapping (must match Breakthrough.get_legal_moves() directions):
#   'B' → 1 = PLAYER1  (direction=+1, moves DOWN, row increases)
#   'W' → 2 = PLAYER2  (direction=-1, moves UP,   row decreases)
#
# Why B→PLAYER1 and not PLAYER2?
#   In breakthrough_mcvs.py, PLAYER1 starts at board[0:2,:] (top rows) with
#   direction=+1, so it naturally represents the top-side piece ('B').
#   PLAYER2 starts at board[6:8,:] (bottom rows) with direction=-1, matching 'W'.
# ─────────────────────────────────────────────────────────────────────────────
_FEN_PIECES = {"B": 1, "W": 2}


def fen_to_board(fen: str) -> np.ndarray:
    """Parse the piece-placement field of a Breakthrough FEN into an 8×8 board.

    Only the first whitespace-separated field is read; a side-to-move field
    (separated by spaces, tabs or other whitespace) is ignored here.

    Args:
        fen: FEN such as "BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w".
            Ranks are separated by '/' and listed top (row 0) to bottom
            (row 7); 'B' and 'W' are pieces, a decimal digit is that many
            empty squares.

    Returns:
        int32 array of shape (8, 8): 0 = empty, 1 = 'B' (PLAYER1),
        2 = 'W' (PLAYER2).

    Raises:
        ValueError: if there are not exactly 8 ranks, a rank does not describe
            exactly 8 squares (including over-full ranks), or an unexpected
            character appears.
    """
    fields = fen.split()
    placement = fields[0] if fields else ""
    rank_parts = placement.split("/")
    if len(rank_parts) != 8:
        raise ValueError(f"FEN must have 8 ranks separated by '/'; got {len(rank_parts)}: {fen!r}")
    board = np.zeros((8, 8), dtype=np.int32)
    for r, rank_str in enumerate(rank_parts):
        c = 0
        for ch in rank_str:
            if ch in _FEN_PIECES:
                # Check before writing: an over-full rank would otherwise hit
                # a numpy IndexError instead of the documented ValueError.
                if c >= 8:
                    raise ValueError(f"Rank {r} has more than 8 columns: {rank_str!r}")
                board[r, c] = _FEN_PIECES[ch]
                c += 1
            elif ch.isdecimal():
                # isdecimal (not isdigit) so int() can always convert the char.
                c += int(ch)  # run of empty squares
            else:
                raise ValueError(f"Unexpected FEN char {ch!r} in rank {r}: {rank_str!r}")
        if c != 8:
            raise ValueError(f"Rank {r} has {c} columns, expected 8: {rank_str!r}")
    return board
# ─────────────────────────────────────────────────────────────────────────────
# 5. get_move — main function exposed via Gradio
# ─────────────────────────────────────────────────────────────────────────────
_SEARCH_TIME_BUDGET_S = 1.0  # budget passed to search_with_time_budget


def _side_to_move_count(fen: str, player: str) -> int:
    """Return move_count parity for the side to move: 1 = White/PLAYER2, 0 = Black/PLAYER1."""
    # Normalise case/whitespace/non-str values so "W", " b " or 2 are honoured.
    choice = "" if player is None else str(player).strip().lower()
    if choice in ("w", "2"):
        return 1
    if choice in ("b", "1"):
        return 0
    # Fall back to the FEN's side-to-move field (the token after the board),
    # parsed as a field rather than by substring search over the whole FEN.
    fields = fen.split()
    side = fields[1].lower() if len(fields) > 1 else ""
    return 1 if side.startswith("w") else 0


def _pick_legal_move(visits: dict, legal_list: list):
    """Return the most-visited legal move, else the first legal move, else None."""
    legal_set = set(legal_list)
    ranked = sorted(visits, key=visits.get, reverse=True) if visits else []
    # search_with_time_budget should only return legal moves, but verify
    # to guard against any residual coordinate-system bugs.
    for candidate in ranked:
        if candidate in legal_set:
            return candidate
    if ranked:
        print(f"[WARNING] no visited move is legal; top visited={ranked[:5]}")
    if legal_list:
        fallback = legal_list[0]
        print(f"[WARNING] using first legal move as fallback: {fallback}")
        return fallback
    return None


def get_move(fen: str, player: str = "") -> str:
    """Choose a move for the side to move in a Breakthrough position.

    Args:
        fen: Breakthrough FEN, e.g. "BBBBBBBB/8/8/8/8/8/8/WWWWWWWW w".
        player: who moves next; case and surrounding whitespace are ignored.
            "w" or "2" → White (PLAYER2, move_count = 1 = odd)
            "b" or "1" → Black (PLAYER1, move_count = 0 = even)
            anything else ("" or None) → taken from the FEN side-to-move
            field: a field starting with "w" means White, otherwise Black.

    Returns:
        A legal UCI-style move string ("e4e3", "b4b5", …). If the search
        yields no legal move, the first legal move is used instead; if the
        position has no legal moves at all, returns "0000" (a sentinel the
        bot-runner recognises as a no-op).

    Raises:
        ValueError: propagated from fen_to_board for a malformed FEN.
    """
    print(f"[DEBUG] get_move FEN board_text={fen!r}, player={player!r}")
    # ── Build game state ──────────────────────────────────────────────────────
    game = Breakthrough()  # noqa: F821 — defined by exec of breakthrough_mcvs.py
    game.board = fen_to_board(fen)
    game._cached_matrix = None  # presumably invalidates state derived from the old board
    # move_count parity selects the side to move: even → PLAYER1, odd → PLAYER2.
    game.move_count = _side_to_move_count(fen, player)

    # ── Run MCVS search ───────────────────────────────────────────────────────
    searcher = MCVSSearcher(  # noqa: F821
        policy_net=None,
        value_net=None,
        zone_db=zonedb,
        lambda_zone=1.0,
        k_zone=5,
    )
    visits, _ = searcher.search_with_time_budget(game, _SEARCH_TIME_BUDGET_S)

    # ── Collect legal moves for validation and fallback ───────────────────────
    legal_list = list(game.get_legal_moves())
    print(f"[DEBUG] legal moves count: {len(legal_list)}")
    if legal_list:
        print(f"[DEBUG] sample legal move: {legal_list[0]}")
    print(f"[DEBUG] visits count: {len(visits)}")
    if visits:
        print(f"[DEBUG] sample visit key: {next(iter(visits))}")

    best_move = _pick_legal_move(visits, legal_list)
    if best_move is None:
        print("[WARNING] no legal moves in position β returning sentinel 0000")
        return "0000"
    uci = move_to_uci(best_move)
    print(f"[DEBUG] get_move OK: FEN={fen} -> {uci}")
    return uci
# ─────────────────────────────────────────────────────────────────────────────
# 6. Gradio Interface
#
# api_name="get_move" → SSE queue endpoint:
#   POST /gradio_api/call/get_move              body: {"data": [fen, player]}
#   GET  /gradio_api/call/get_move/{event_id}   (SSE)
#        → event: complete
#          data: ["e4e3"]
#
# NOTE(review): the /gradio_api/ path prefix is what Gradio 5 serves; Gradio 4
# exposes the same route as /call/get_move — confirm against the installed
# version.
#
# This matches what predict_breakthrough._try_space_api sends.
# ─────────────────────────────────────────────────────────────────────────────
demo = gr.Interface(
    fn=get_move,
    inputs=[
        gr.Textbox(
            label="Breakthrough FEN",
            placeholder="BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w",
        ),
        gr.Textbox(
            label="Player to move (w / b)",
            placeholder="w",
            value="",
        ),
    ],
    outputs=gr.Textbox(label="UCI Move"),
    title="Breakthrough Move Predictor",
    description=(
        "Returns a legal UCI move for the given Breakthrough position. "
        "FEN format: BBBBBBBB/.../WWWWWWWW followed by w or b."
    ),
    api_name="get_move",  # → /gradio_api/call/get_move
)

# Launch only when run as a script (python app.py); importing the module
# (e.g. from tests or `gradio app.py` reload mode) then just builds `demo`.
if __name__ == "__main__":
    demo.launch()