File size: 12,933 Bytes
d0d3ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b9b7dad
 
 
 
 
 
d0d3ba8
 
 
 
 
 
 
 
f1a347e
 
2d3953c
f1a347e
d0d3ba8
 
f1a347e
d0d3ba8
 
 
 
 
f1a347e
 
2d3953c
f1a347e
d0d3ba8
2d3953c
d0d3ba8
 
 
 
f1a347e
d0d3ba8
 
 
 
 
 
1bca279
d0d3ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1a347e
2d3953c
d0d3ba8
2d3953c
d0d3ba8
 
2d3953c
 
d0d3ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38d2308
d0d3ba8
38d2308
d0d3ba8
38d2308
d0d3ba8
 
38d2308
 
d0d3ba8
 
38d2308
 
d0d3ba8
38d2308
d0d3ba8
38d2308
d0d3ba8
38d2308
1bca279
d0d3ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bca279
d0d3ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bca279
c48475d
d0d3ba8
 
 
 
 
 
c48475d
 
 
 
d0d3ba8
 
 
8259be7
c48475d
8259be7
 
 
 
 
d0d3ba8
8259be7
d0d3ba8
8259be7
d0d3ba8
 
8259be7
 
d0d3ba8
 
 
 
 
 
c48475d
b9b7dad
d0d3ba8
 
 
 
 
 
 
 
 
 
 
b9b7dad
 
d0d3ba8
 
 
 
f1a347e
d0d3ba8
 
 
 
 
 
2d3953c
d0d3ba8
 
 
 
 
 
b9b7dad
1bca279
b9b7dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# ──────────────────────────────────────────────────────────────────────────────
# space_app.py — HF Gradio Space for Breakthrough move prediction
#
# Deploy this file as app.py in your Hugging Face Space.
#
# Coordinate system (identical to apps/games/breakthrough_engine.py):
#   rows 0–7, 0-indexed
#   row 0 = rank 8 = TOP  (Black / 'B' side)
#   row 7 = rank 1 = BOTTOM (White / 'W' side)
#   cols 0–7 → 'a'–'h'
#   rank formula: rank = 8 − row
#
# Piece → MCVS player mapping (matches engine directions):
#   'B' → PLAYER1 = 1, direction = +1 (moves DOWN, row increases toward rank 1)
#   'W' → PLAYER2 = 2, direction = −1 (moves UP,   row decreases toward rank 8)
#
# move_to_uci is the exact inverse of breakthrough_engine._sq_to_coords:
#   (fr, fc, tr, tc) → cols[fc] + str(8−fr) + cols[tc] + str(8−tr)
#
# Gradio endpoint exposed:
#   api_name="get_move" → /gradio_api/call/get_move  (Gradio 4 SSE queue)
#   inputs: [fen: str, player: str]  → POST {"data": [fen, player]}
#   output: uci_move: str            ← SSE  data: ["e4e3"]
# ──────────────────────────────────────────────────────────────────────────────
import os

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

# Optional Hub access token; None means anonymous access to the repos below.
TOKEN = os.environ.get("HF_TOKEN")

# ─────────────────────────────────────────────────────────────────────────────
# 1.  Load breakthrough_mcvs.py from Hub model repo into this namespace.
#     After exec() the following names are available globally:
#       Breakthrough, MCVSSearcher, HilbertOrderedZoneDatabase,
#       ABCModelDynamic, WeightedMatrixABC, move_to_index, …
# ─────────────────────────────────────────────────────────────────────────────
_model_path = hf_hub_download(
    "test1978/breakthrough-model",
    "breakthrough_mcvs.py",
    repo_type="model",
    token=TOKEN,
)
# NOTE(review): exec() runs whatever the Hub repo currently serves. This is only
# acceptable while that repo is trusted and write-restricted; consider pinning a
# `revision=` in hf_hub_download so a repo push cannot change the executed code.
# "utf-8-sig" tolerates a leading BOM in the downloaded file.
with open(_model_path, "r", encoding="utf-8-sig") as _fh:
    exec(_fh.read(), globals())  # noqa: S102 — trusted internal model file

# ─────────────────────────────────────────────────────────────────────────────
# 2.  Load the zone database from Hub dataset repo.
# ─────────────────────────────────────────────────────────────────────────────
_db_path = hf_hub_download(
    "test1978/breakthrough-data",
    "breakthrough_zone_db.npz",
    repo_type="dataset",
    token=TOKEN,
)
# np.load on an .npz returns a lazily-read NpzFile that keeps the archive open.
# The context manager closes it once every array below has been materialised.
# NOTE(review): allow_pickle=True unpickles downloaded data — safe only while the
# dataset repo is trusted.
with np.load(_db_path, allow_pickle=True) as _db_data:
    zonedb = HilbertOrderedZoneDatabase()  # noqa: F821 — defined by exec above
    # Missing keys fall back to an empty list, so a partial archive still loads.
    zonedb.winning_matrices = list(_db_data.get("winning", []))
    zonedb.losing_matrices  = list(_db_data.get("losing",  []))
    zonedb.draw_matrices    = list(_db_data.get("draw",    []))

print(
    f"[INIT] Zone DB loaded: "
    f"W={len(zonedb.winning_matrices)} "
    f"L={len(zonedb.losing_matrices)} "
    f"D={len(zonedb.draw_matrices)}"
)

# ─────────────────────────────────────────────────────────────────────────────
# 3.  move_to_uci — exact inverse of breakthrough_engine._sq_to_coords
#
# breakthrough_engine._coords_to_sq(row, col):
#     rank_idx = 7 - row          # row 0 → rank_idx 7 → rank "8"
#     return f"{_FILES[col]}{_RANKS[rank_idx]}"   # _RANKS = "12345678"
#   ⟹ rank = rank_idx + 1 = 8 - row
#
# Therefore:
#   move_to_uci((fr, fc, tr, tc))  =  cols[fc] + str(8−fr) + cols[tc] + str(8−tr)
#
# Verification:
#   (4, 1, 3, 1) → 'b' + str(4) + 'b' + str(5) = 'b4b5'
#   bot-runner: 'b4' = _sq_to_coords('b4') = (4,1)  ✓
#               'b5' = _sq_to_coords('b5') = (3,1)  ✓  → move (4,1)→(3,1)
#   (7, 0, 6, 0) → 'a1a2'  (White, row 7→6, rank 1→2)       ✓
#   (0, 3, 1, 3) → 'd8d7'  (Black, row 0→1, rank 8→7)       ✓
# ─────────────────────────────────────────────────────────────────────────────
_COLS = "abcdefgh"


def move_to_uci(move: tuple) -> str:
    """Encode a ``(from_row, from_col, to_row, to_col)`` move as a UCI string.

    Columns map to files 'a'..'h' and rows map to ranks via ``8 - row``,
    so row 0 is rank 8 and row 7 is rank 1.  Example: ``(4, 1, 3, 1)``
    becomes ``"b4b5"``.
    """
    from_row, from_col, to_row, to_col = move

    def _square(row, col):
        # File letter first, then the rank counted up from the bottom edge.
        return f"{_COLS[col]}{8 - row}"

    return _square(from_row, from_col) + _square(to_row, to_col)


# ─────────────────────────────────────────────────────────────────────────────
# 4.  fen_to_board — FEN string → Breakthrough.board (numpy int32 array, 8×8)
# ─────────────────────────────────────────────────────────────────────────────
def fen_to_board(fen: str) -> np.ndarray:
    """Parse the board field of a Breakthrough FEN into an 8×8 int32 array.

    Example FEN: ``"BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w"``.
    Only the first whitespace-separated field (the board) is read.  Any
    side-to-move suffix is ignored here.  The first rank string is row 0
    (rank 8, top, Black side) and the last is row 7 (rank 1, bottom, White
    side).

    Cell encoding (must match Breakthrough.get_legal_moves() directions):
      'B' → 1 = PLAYER1  (direction=+1, moves DOWN, row increases)
      'W' → 2 = PLAYER2  (direction=−1, moves UP,   row decreases)
      digit n → n empty (0) squares

    Why 'B' → PLAYER1?  Per breakthrough_mcvs.py, PLAYER1 starts at
    board[0:2, :] (top rows) with direction=+1, so it is the top-side piece.
    PLAYER2 starts at board[6:8, :] with direction=−1, matching 'W'.

    Raises
    ------
    ValueError
        If the board field does not have exactly 8 '/'-separated ranks, a
        rank does not describe exactly 8 squares, or a rank contains a
        character other than 'B', 'W' or a digit.
    """
    # split() (no argument) accepts any whitespace between fields, e.g. tabs.
    fields = (fen or "").split()
    ranks_str = fields[0] if fields else ""
    rank_parts = ranks_str.split("/")
    if len(rank_parts) != 8:
        raise ValueError(f"FEN must have 8 ranks separated by '/'; got {len(rank_parts)}: {fen!r}")

    board = np.zeros((8, 8), dtype=np.int32)
    for r, rank_str in enumerate(rank_parts):
        c = 0
        for ch in rank_str:
            if ch == "B" or ch == "W":
                # Check before writing: an over-long rank would otherwise hit a
                # numpy IndexError instead of the documented ValueError.
                if c >= 8:
                    raise ValueError(f"Rank {r} has more than 8 columns, expected 8: {rank_str!r}")
                board[r, c] = 1 if ch == "B" else 2
                c += 1
            elif ch.isdigit():
                c += int(ch)   # run of empty squares
            else:
                raise ValueError(f"Unexpected FEN char {ch!r} in rank {r}: {rank_str!r}")
        if c != 8:
            raise ValueError(f"Rank {r} has {c} columns, expected 8: {rank_str!r}")
    return board


# ─────────────────────────────────────────────────────────────────────────────
# 5.  get_move — main function exposed via Gradio
# ─────────────────────────────────────────────────────────────────────────────
def _side_to_move_parity(fen, player) -> int:
    """Return the ``move_count`` parity for the side to move (1 White, 0 Black).

    An explicit *player* wins (case and surrounding whitespace ignored).
    Otherwise the FEN's side-to-move field, the token after the board, is
    consulted.  Black is assumed when that field is missing or not 'w...'.
    """
    side = str(player or "").strip().lower()
    if side in ("w", "white", "2"):
        return 1                      # White = PLAYER2 = odd move_count
    if side in ("b", "black", "1"):
        return 0                      # Black = PLAYER1 = even move_count
    fields = str(fen or "").split()
    # Inspect only the side-to-move field rather than searching the whole string.
    return 1 if len(fields) > 1 and fields[1][:1].lower() == "w" else 0


def get_move(fen: str, player: str = "") -> str:
    """Return the engine's chosen move for a Breakthrough position in UCI form.

    Parameters
    ----------
    fen : str
        Breakthrough FEN, e.g. ``"BBBBBBBB/8/8/8/8/8/8/WWWWWWWW w"``.
    player : str, optional
        Who moves next.  Case and surrounding whitespace are ignored:
          "w" / "white" / "2"  → White (PLAYER2, move_count = 1 = odd)
          "b" / "black" / "1"  → Black (PLAYER1, move_count = 0 = even)
          anything else, e.g. "" or None → taken from the FEN side-to-move
          field ("w" → White, otherwise Black).

    Returns
    -------
    str
        A legal UCI move string ("e4e3", "b4b5", …).  The most-visited legal
        move from a 1-second MCVS search is preferred.  If the search yields
        no legal candidate, the first legal move is used.  If the position
        has no legal moves at all, "0000" is returned without searching.
        Per the original notes, the bot-runner treats this sentinel as a no-op.

    Raises
    ------
    ValueError
        Propagated from ``fen_to_board`` when the FEN board field is malformed.
    """
    print(f"[DEBUG] get_move FEN board_text={fen!r}, player={player!r}")

    # ── Build game state ──────────────────────────────────────────────────────
    game = Breakthrough()             # noqa: F821 — defined by exec() at import
    game.board = fen_to_board(fen)
    # The board was replaced directly, so drop any state cached from the old one.
    # NOTE(review): attribute semantics come from breakthrough_mcvs.py — confirm.
    game._cached_matrix = None
    # move_count parity selects the side to move:
    #   even → PLAYER1 (Black / 'b'), odd → PLAYER2 (White / 'w')
    game.move_count = _side_to_move_parity(fen, player)

    # ── Collect legal moves before searching ──────────────────────────────────
    # Doing this first means terminal positions return immediately instead of
    # being handed to the searcher.  The list also reflects the input position
    # even if the search were to leave `game` modified.
    legal_list = list(game.get_legal_moves())
    print(f"[DEBUG] legal moves count: {len(legal_list)}")
    if not legal_list:
        print("[WARNING] no legal moves in position — returning sentinel 0000")
        return "0000"
    print(f"[DEBUG] sample legal move: {legal_list[0]}")
    legal_set = set(legal_list)

    # ── Run MCVS search ───────────────────────────────────────────────────────
    searcher = MCVSSearcher(          # noqa: F821
        policy_net=None,
        value_net=None,
        zone_db=zonedb,
        lambda_zone=1.0,
        k_zone=5,
    )
    visits, _ = searcher.search_with_time_budget(game, 1.0)
    print(f"[DEBUG] visits count: {len(visits)}")
    if visits:
        print(f"[DEBUG] sample visit key: {next(iter(visits))}")

    # ── Pick the most-visited move that is actually legal ─────────────────────
    # search_with_time_budget should only return legal moves, but we verify
    # to guard against any residual coordinate-system bugs.  max() keeps the
    # first of equally-visited candidates, as a stable descending sort would.
    best_move = max(
        (candidate for candidate in visits if candidate in legal_set),
        key=visits.get,
        default=None,
    )

    # ── Fallback: first legal move from get_legal_moves() ────────────────────
    if best_move is None:
        if visits:
            print(f"[WARNING] no visited move is legal; top visited={list(visits.keys())[:5]}")
        best_move = legal_list[0]
        print(f"[WARNING] using first legal move as fallback: {best_move}")

    uci = move_to_uci(best_move)
    print(f"[DEBUG] get_move OK: FEN={fen} -> {uci}")
    return uci


# ─────────────────────────────────────────────────────────────────────────────
# 6.  Gradio Interface
#
# api_name="get_move"  →  Gradio 4 SSE queue endpoint:
#   POST /gradio_api/call/get_move  body: {"data": [fen, player]}
#   GET  /gradio_api/call/get_move/{event_id}  (SSE)
#   → event: complete
#     data: ["e4e3"]
#
# This matches what predict_breakthrough._try_space_api sends.
# ─────────────────────────────────────────────────────────────────────────────

# Input components, in the positional order get_move(fen, player) expects.
_fen_input = gr.Textbox(
    label="Breakthrough FEN",
    placeholder="BBBBBBBB/BBBBBBBB/8/8/8/8/WWWWWWWW/WWWWWWWW w",
)
# Empty by default so get_move falls back to the FEN side-to-move suffix.
_player_input = gr.Textbox(
    label="Player to move  (w / b)",
    placeholder="w",
    value="",
)
_move_output = gr.Textbox(label="UCI Move")

demo = gr.Interface(
    fn=get_move,
    inputs=[_fen_input, _player_input],
    outputs=_move_output,
    title="Breakthrough Move Predictor",
    description=(
        "Returns a legal UCI move for the given Breakthrough position. "
        "FEN format: BBBBBBBB/.../WWWWWWWW followed by  w  or  b."
    ),
    api_name="get_move",   # → /gradio_api/call/get_move
)

demo.launch()