# Source: Hugging Face Space file by test1978
# Last commit: "update app.py" (d0d3ba8, verified)
# ──────────────────────────────────────────────────────────────────────────────
# space_app.py β€” HF Gradio Space for Breakthrough move prediction
#
# Deploy this file as app.py in your Hugging Face Space.
#
# Coordinate system (identical to apps/games/breakthrough_engine.py):
#   rows 0–7, 0-indexed
#   row 0 = rank 8 = TOP    (Black / 'B' side)
#   row 7 = rank 1 = BOTTOM (White / 'W' side)
#   cols 0–7 → 'a'–'h'
#   rank formula: rank = 8 − row
#
# Piece → MCVS player mapping (matches engine directions):
#   'B' → PLAYER1 = 1, direction = +1 (moves DOWN, row increases toward rank 1)
#   'W' → PLAYER2 = 2, direction = −1 (moves UP, row decreases toward rank 8)
#
# move_to_uci is the exact inverse of breakthrough_engine._sq_to_coords:
#   (fr, fc, tr, tc) → cols[fc] + str(8−fr) + cols[tc] + str(8−tr)
#
# Gradio endpoint exposed:
#   api_name="get_move" → /gradio_api/call/get_move (Gradio 4 SSE queue)
#   inputs: [fen: str, player: str] → POST {"data": [fen, player]}
#   output: uci_move: str ← SSE data: ["e4e3"]
# ──────────────────────────────────────────────────────────────────────────────
import os
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
# Hub access token read from the HF_TOKEN environment variable (None when the
# variable is unset); passed as token= to the hf_hub_download calls below.
TOKEN = os.environ.get("HF_TOKEN")
# ─────────────────────────────────────────────────────────────────────────────
# 1. Load breakthrough_mcvs.py from Hub model repo into this namespace.
# After exec() the following names are available globally:
# Breakthrough, MCVSSearcher, HilbertOrderedZoneDatabase,
# ABCModelDynamic, WeightedMatrixABC, move_to_index, …
# ─────────────────────────────────────────────────────────────────────────────
# Resolve a local path to the engine source stored in the Hub model repo.
_model_path = hf_hub_download(
    repo_id="test1978/breakthrough-model",
    filename="breakthrough_mcvs.py",
    repo_type="model",
    token=TOKEN,
)

# "utf-8-sig" decodes UTF-8 and drops a leading byte-order mark if one is present.
with open(_model_path, "r", encoding="utf-8-sig") as _fh:
    _model_source = _fh.read()

# Execute the engine source against this module's globals so that every name it
# defines becomes a module-level name here.
exec(_model_source, globals())  # noqa: S102 — trusted internal model file
# ─────────────────────────────────────────────────────────────────────────────
# 2. Load the zone database from Hub dataset repo.
# ─────────────────────────────────────────────────────────────────────────────
_db_path = hf_hub_download(
    "test1978/breakthrough-data",
    "breakthrough_zone_db.npz",
    repo_type="dataset",
    token=TOKEN,
)


def _npz_list(npz, key):
    """Return the array stored under *key* in *npz* as a list, or [] if absent.

    Membership is tested against ``npz.files`` rather than calling ``npz.get``
    so the lookup does not depend on NpzFile offering the Mapping API.
    """
    return list(npz[key]) if key in npz.files else []


zonedb = HilbertOrderedZoneDatabase()  # noqa: F821 — defined by exec above

# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary objects from
# the archive. Kept because the stored arrays may be object arrays, but this is
# only acceptable while the dataset repo is fully trusted.
# The context manager closes the underlying zip file once the arrays have been
# materialised into lists.
with np.load(_db_path, allow_pickle=True) as _db_data:
    zonedb.winning_matrices = _npz_list(_db_data, "winning")
    zonedb.losing_matrices = _npz_list(_db_data, "losing")
    zonedb.draw_matrices = _npz_list(_db_data, "draw")

print(
    f"[INIT] Zone DB loaded: "
    f"W={len(zonedb.winning_matrices)} "
    f"L={len(zonedb.losing_matrices)} "
    f"D={len(zonedb.draw_matrices)}"
)
# ─────────────────────────────────────────────────────────────────────────────
# 3. move_to_uci β€” exact inverse of breakthrough_engine._sq_to_coords
#
# breakthrough_engine._coords_to_sq(row, col):
# rank_idx = 7 - row # row 0 β†’ rank_idx 7 β†’ rank "8"
# return f"{_FILES[col]}{_RANKS[rank_idx]}" # _RANKS = "12345678"
# ⟹ rank = rank_idx + 1 = 8 - row
#
# Therefore:
# move_to_uci((fr, fc, tr, tc)) = cols[fc] + str(8βˆ’fr) + cols[tc] + str(8βˆ’tr)
#
# Verification:
# (4, 1, 3, 1) β†’ 'b' + str(4) + 'b' + str(5) = 'b4b5'
# bot-runner: 'b4' = _sq_to_coords('b4') = (4,1) βœ“
# 'b5' = _sq_to_coords('b5') = (3,1) βœ“ β†’ move (4,1)β†’(3,1)
# (7, 0, 6, 0) β†’ 'a1a2' (White, row 7β†’6, rank 1β†’2) βœ“
# (0, 3, 1, 3) β†’ 'd8d7' (Black, row 0β†’1, rank 8β†’7) βœ“
# ─────────────────────────────────────────────────────────────────────────────
_COLS = "abcdefgh"


def move_to_uci(move: tuple) -> str:
    """Render a ``(from_row, from_col, to_row, to_col)`` move as a UCI string.

    Each square is written as its file letter (``_COLS[col]``) followed by its
    rank, where ``rank = 8 - row``; the source square comes first.
    """
    from_row, from_col, to_row, to_col = move
    endpoints = ((from_col, from_row), (to_col, to_row))
    return "".join(_COLS[col] + str(8 - row) for col, row in endpoints)
# ─────────────────────────────────────────────────────────────────────────────
# 4. fen_to_board β€” FEN string β†’ Breakthrough.board (numpy int32 array, 8Γ—8)
#
# FEN example: "BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w"
# First rank string = row 0 = rank 8 = top (Black side)
# Last rank string = row 7 = rank 1 = bottom (White side)
#
# Mapping (must match Breakthrough.get_legal_moves() directions):
# 'B' β†’ 1 = PLAYER1 (direction=+1, moves DOWN, row increases)
# 'W' β†’ 2 = PLAYER2 (direction=βˆ’1, moves UP, row decreases)
#
# Why B→PLAYER1 and not PLAYER2?
# In breakthrough_mcvs.py, PLAYER1 starts at board[0:2,:] (top rows) with
# direction=+1, so it naturally represents the top-side piece ('B').
# PLAYER2 starts at board[6:8,:] (bottom rows) with direction=βˆ’1, matching 'W'.
# ─────────────────────────────────────────────────────────────────────────────
def fen_to_board(fen: str) -> np.ndarray:
    """Convert a Breakthrough FEN string into an 8x8 ``int32`` board array.

    Only the first space-separated field of *fen* (the piece placement) is
    read; anything after it is ignored here.  The first rank string fills
    row 0 and the last one fills row 7.

    Cell values: ``0`` for an empty square, ``1`` for ``'B'``, ``2`` for ``'W'``.
    A digit skips that many empty squares.

    Raises:
        ValueError: if the placement does not contain exactly 8 ranks, if a
            rank contains a character other than ``B``, ``W`` or a digit, or
            if a rank does not describe exactly 8 columns.
    """
    placement = fen.strip().split(" ")[0]
    rank_parts = placement.split("/")
    if len(rank_parts) != 8:
        raise ValueError(f"FEN must have 8 ranks separated by '/'; got {len(rank_parts)}: {fen!r}")

    piece_codes = {"B": 1, "W": 2}
    board = np.zeros((8, 8), dtype=np.int32)
    for r, rank_str in enumerate(rank_parts):
        c = 0
        for ch in rank_str:
            if ch in piece_codes:
                # Check bounds before writing: an over-long rank must be
                # reported as a malformed FEN (ValueError), not surface as an
                # IndexError from the array access.
                if c >= 8:
                    raise ValueError(f"Rank {r} has more than 8 columns: {rank_str!r}")
                board[r, c] = piece_codes[ch]
                c += 1
            elif ch.isdigit():
                c += int(ch)  # run of empty squares
            else:
                raise ValueError(f"Unexpected FEN char {ch!r} in rank {r}: {rank_str!r}")
        if c != 8:
            raise ValueError(f"Rank {r} has {c} columns, expected 8: {rank_str!r}")
    return board
# ─────────────────────────────────────────────────────────────────────────────
# 5. get_move β€” main function exposed via Gradio
#
# Parameters
# ----------
# fen : Breakthrough FEN, e.g. "BBBBBBBB/8/8/8/8/8/8/WWWWWWWW w"
# player : who moves next; accepted values:
# "w" or "2" β†’ White (PLAYER2, move_count = 1 = odd)
# "b" or "1" β†’ Black (PLAYER1, move_count = 0 = even)
# "" or None β†’ infer from FEN side-to-move suffix
# " w" β†’ White (move_count=1)
# " b" β†’ Black (move_count=0)
#
# Returns
# -------
# uci : str β€” a legal UCI move string ("e4e3", "b4b5", …).
# If search produces no legal move and no legal move exists at all,
# returns "0000" (a safe sentinel the bot-runner recognises as no-op).
# ─────────────────────────────────────────────────────────────────────────────
_WHITE_ALIASES = frozenset({"w", "2", "white"})
_BLACK_ALIASES = frozenset({"b", "1", "black"})


def _resolve_move_count(fen: str, player) -> int:
    """Return the move_count parity: 1 if White (PLAYER2) is to move, else 0."""
    choice = "" if player is None else str(player).strip().lower()
    if choice in _WHITE_ALIASES:
        return 1
    if choice in _BLACK_ALIASES:
        return 0
    if choice:
        print(f"[WARNING] unrecognised player value {player!r}; inferring side to move from FEN")
    # Look only at the fields AFTER the piece placement. A plain substring
    # search for " w" would also match a placement that starts with 'W' when
    # the FEN carries leading whitespace.
    trailing_fields = fen.lower().split()[1:]
    return 1 if any(field.startswith("w") for field in trailing_fields) else 0


def get_move(fen: str, player: str = "") -> str:
    """Return a UCI move for the given Breakthrough position.

    Args:
        fen: Breakthrough FEN; the piece placement is parsed by
            ``fen_to_board`` and an optional trailing ``w``/``b`` field names
            the side to move.
        player: Side to move. ``"w"``, ``"2"`` or ``"white"`` select White
            (``move_count = 1``); ``"b"``, ``"1"`` or ``"black"`` select Black
            (``move_count = 0``). Matching ignores case and surrounding
            whitespace. An empty, ``None`` or unrecognised value falls back to
            the FEN's side-to-move field, defaulting to Black when absent.

    Returns:
        The most-visited search move that is also in the legal-move list,
        else the first legal move, else the sentinel ``"0000"`` when the
        position has no legal moves.

    Raises:
        ValueError: propagated from ``fen_to_board`` for a malformed FEN.
    """
    print(f"[DEBUG] get_move FEN board_text={fen!r}, player={player!r}")

    # ── Build game state ──────────────────────────────────────────────────────
    game = Breakthrough()  # noqa: F821
    game.board = fen_to_board(fen)
    # Presumably drops a matrix cached for the previous board — confirm against
    # breakthrough_mcvs.py.
    game._cached_matrix = None
    # move_count parity selects the mover: even → PLAYER1 (Black), odd → PLAYER2 (White).
    game.move_count = _resolve_move_count(fen, player)

    # ── Run MCVS search ───────────────────────────────────────────────────────
    searcher = MCVSSearcher(  # noqa: F821
        policy_net=None,
        value_net=None,
        zone_db=zonedb,
        lambda_zone=1.0,
        k_zone=5,
    )
    visits, _ = searcher.search_with_time_budget(game, 1.0)
    # Treat a missing result like an empty one so the diagnostics below cannot crash.
    visits = visits or {}

    # ── Collect legal moves for validation and fallback ───────────────────────
    legal_list = list(game.get_legal_moves())
    legal_set = set(legal_list)
    print(f"[DEBUG] legal moves count: {len(legal_list)}")
    if legal_list:
        print(f"[DEBUG] sample legal move: {legal_list[0]}")
    print(f"[DEBUG] visits count: {len(visits)}")
    if visits:
        print(f"[DEBUG] sample visit key: {next(iter(visits))}")

    # ── Pick the best visited move that is actually legal ─────────────────────
    # The search is expected to return only legal moves; the membership check
    # guards against any residual coordinate-system mismatch.
    best_move = None
    if visits:
        for candidate in sorted(visits, key=visits.get, reverse=True):
            if candidate in legal_set:
                best_move = candidate
                break
        if best_move is None:
            print(f"[WARNING] no visited move is legal; top visited={list(visits.keys())[:5]}")

    # ── Fallback: first legal move from get_legal_moves() ─────────────────────
    if best_move is None:
        if not legal_list:
            print("[WARNING] no legal moves in position — returning sentinel 0000")
            return "0000"
        best_move = legal_list[0]
        print(f"[WARNING] using first legal move as fallback: {best_move}")

    uci = move_to_uci(best_move)
    print(f"[DEBUG] get_move OK: FEN={fen} -> {uci}")
    return uci
# ─────────────────────────────────────────────────────────────────────────────
# 6. Gradio Interface
#
# api_name="get_move" β†’ Gradio 4 SSE queue endpoint:
# POST /gradio_api/call/get_move body: {"data": [fen, player]}
# GET /gradio_api/call/get_move/{event_id} (SSE)
# β†’ event: complete
# data: ["e4e3"]
#
# This matches what predict_breakthrough._try_space_api sends.
# ─────────────────────────────────────────────────────────────────────────────
# Input widgets, listed in the positional order of get_move(fen, player).
_fen_box = gr.Textbox(
    label="Breakthrough FEN",
    placeholder="BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w",
)
_player_box = gr.Textbox(
    label="Player to move (w / b)",
    placeholder="w",
    value="",
)
# Single output widget showing the string returned by get_move.
_move_box = gr.Textbox(label="UCI Move")

_DESCRIPTION = (
    "Returns a legal UCI move for the given Breakthrough position. "
    "FEN format: BBBBBBBB/.../WWWWWWWW followed by w or b."
)

demo = gr.Interface(
    fn=get_move,
    inputs=[_fen_box, _player_box],
    outputs=_move_box,
    title="Breakthrough Move Predictor",
    description=_DESCRIPTION,
    api_name="get_move",  # → /gradio_api/call/get_move
)

# Start the app as soon as this module is executed.
demo.launch()