Rafs-an09002 committed (verified)
Commit b6d27d2 · 1 Parent(s): 61f7235

Create model_loader.py

Files changed (1)
  1. model_loader.py +207 -0
model_loader.py ADDED
@@ -0,0 +1,207 @@
+ """
+ ONNX Model Loader for Synapse-Base
+ Handles model loading and inference
+ CPU-optimized for HF Spaces
+ """
+
+ import onnxruntime as ort
+ import numpy as np
+ import chess
+ import logging
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+
+ class SynapseModel:
+     """ONNX Runtime wrapper for Synapse-Base model"""
+
+     def __init__(self, model_path: str, num_threads: int = 2):
+         """
+         Initialize model
+
+         Args:
+             model_path: Path to ONNX model file
+             num_threads: Number of CPU threads to use
+         """
+         self.model_path = Path(model_path)
+
+         if not self.model_path.exists():
+             raise FileNotFoundError(f"Model not found: {model_path}")
+
+         # ONNX Runtime session options (CPU optimized)
+         sess_options = ort.SessionOptions()
+         sess_options.intra_op_num_threads = num_threads
+         sess_options.inter_op_num_threads = num_threads
+         sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+         sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         # Create session
+         logger.info(f"Loading model from {model_path}...")
+         self.session = ort.InferenceSession(
+             str(self.model_path),
+             sess_options=sess_options,
+             providers=['CPUExecutionProvider']
+         )
+
+         # Get input/output names
+         self.input_name = self.session.get_inputs()[0].name
+         self.output_names = [output.name for output in self.session.get_outputs()]
+
+         logger.info(f"✅ Model loaded: {self.input_name} -> {self.output_names}")
+
+     def fen_to_tensor(self, fen: str) -> np.ndarray:
+         """
+         Convert FEN to 119-channel tensor
+
+         Args:
+             fen: FEN string
+
+         Returns:
+             numpy array of shape (1, 119, 8, 8)
+         """
+         board = chess.Board(fen)
+         tensor = np.zeros((1, 119, 8, 8), dtype=np.float32)
+
+         # === CHANNELS 0-11: Piece Positions ===
+         piece_map = board.piece_map()
+         piece_to_channel = {
+             chess.PAWN: 0, chess.KNIGHT: 1, chess.BISHOP: 2,
+             chess.ROOK: 3, chess.QUEEN: 4, chess.KING: 5
+         }
+
+         for square, piece in piece_map.items():
+             rank = square // 8
+             file = square % 8
+             channel = piece_to_channel[piece.piece_type]
+             if piece.color == chess.BLACK:
+                 channel += 6
+             tensor[0, channel, rank, file] = 1.0
+
+         # === CHANNELS 12-26: Game State Metadata ===
+         # Channel 12: Turn (1 = white to move)
+         tensor[0, 12, :, :] = 1.0 if board.turn == chess.WHITE else 0.0
+
+         # Channels 13-16: Castling rights
+         tensor[0, 13, :, :] = float(board.has_kingside_castling_rights(chess.WHITE))
+         tensor[0, 14, :, :] = float(board.has_queenside_castling_rights(chess.WHITE))
+         tensor[0, 15, :, :] = float(board.has_kingside_castling_rights(chess.BLACK))
+         tensor[0, 16, :, :] = float(board.has_queenside_castling_rights(chess.BLACK))
+
+         # Channel 17: En passant square
+         if board.ep_square is not None:
+             ep_rank = board.ep_square // 8
+             ep_file = board.ep_square % 8
+             tensor[0, 17, ep_rank, ep_file] = 1.0
+
+         # Channel 18: Halfmove clock (normalized)
+         tensor[0, 18, :, :] = min(board.halfmove_clock / 100.0, 1.0)
+
+         # Channel 19: Fullmove number (normalized)
+         tensor[0, 19, :, :] = min(board.fullmove_number / 100.0, 1.0)
+
+         # Channels 20-21: Check status
+         tensor[0, 20, :, :] = float(board.is_check() and board.turn == chess.WHITE)
+         tensor[0, 21, :, :] = float(board.is_check() and board.turn == chess.BLACK)
+
+         # Channels 22-26: Material count (normalized)
+         white_pawns = len(board.pieces(chess.PAWN, chess.WHITE))
+         black_pawns = len(board.pieces(chess.PAWN, chess.BLACK))
+         tensor[0, 22, :, :] = white_pawns / 8.0
+         tensor[0, 23, :, :] = black_pawns / 8.0
+
+         white_knights = len(board.pieces(chess.KNIGHT, chess.WHITE))
+         black_knights = len(board.pieces(chess.KNIGHT, chess.BLACK))
+         tensor[0, 24, :, :] = white_knights / 2.0
+         tensor[0, 25, :, :] = black_knights / 2.0
+
+         white_bishops = len(board.pieces(chess.BISHOP, chess.WHITE))
+         black_bishops = len(board.pieces(chess.BISHOP, chess.BLACK))
+         tensor[0, 26, :, :] = white_bishops / 2.0
+
+         # === CHANNELS 27-50: Attack Maps ===
+         # White attacks
+         for square in chess.SQUARES:
+             if board.is_attacked_by(chess.WHITE, square):
+                 rank = square // 8
+                 file = square % 8
+                 tensor[0, 27, rank, file] = 1.0
+
+         # Black attacks
+         for square in chess.SQUARES:
+             if board.is_attacked_by(chess.BLACK, square):
+                 rank = square // 8
+                 file = square % 8
+                 tensor[0, 28, rank, file] = 1.0
+
+         # === CHANNELS 51-66: Coordinate Encoding ===
+         # Rank encoding
+         for rank in range(8):
+             tensor[0, 51 + rank, rank, :] = 1.0
+
+         # File encoding
+         for file in range(8):
+             tensor[0, 59 + file, :, file] = 1.0
+
+         # === CHANNELS 67-118: Positional Biases (Static) ===
+         # Center control bonus
+         center_squares = [chess.D4, chess.D5, chess.E4, chess.E5]
+         for square in center_squares:
+             rank = square // 8
+             file = square % 8
+             tensor[0, 67, rank, file] = 0.5
+
+         # King safety zones
+         for color_offset, color in [(0, chess.WHITE), (1, chess.BLACK)]:
+             king_square = board.king(color)
+             if king_square is not None:
+                 king_rank = king_square // 8
+                 king_file = king_square % 8
+
+                 # Mark king zone (3x3 around king)
+                 for dr in [-1, 0, 1]:
+                     for df in [-1, 0, 1]:
+                         r = king_rank + dr
+                         f = king_file + df
+                         if 0 <= r < 8 and 0 <= f < 8:
+                             tensor[0, 68 + color_offset, r, f] = 1.0
+
+         # Remaining channels stay at zero (placeholder for future features)
+         # Channels 29-50 and 70-118 are currently unused
+
+         return tensor
+
+     def evaluate(self, fen: str) -> dict:
+         """
+         Evaluate position
+
+         Args:
+             fen: FEN string
+
+         Returns:
+             dict with 'value' and optionally 'policy'
+         """
+         # Convert FEN to tensor
+         input_tensor = self.fen_to_tensor(fen)
+
+         # Run inference
+         outputs = self.session.run(
+             self.output_names,
+             {self.input_name: input_tensor}
+         )
+
+         # Parse outputs
+         result = {}
+
+         # Value head (always first output)
+         result['value'] = float(outputs[0][0][0])
+
+         # Policy head (if available)
+         if len(outputs) > 1:
+             result['policy'] = outputs[1][0]
+
+         return result
+
+     def get_size_mb(self) -> float:
+         """Get model size in MB"""
+         return self.model_path.stat().st_size / (1024 * 1024)
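
A minimal usage sketch for this loader. The model filename synapse_base.onnx below is an assumption, not part of the commit; substitute whatever ONNX path the Space actually ships with.

    import chess
    from model_loader import SynapseModel

    # Load the model with two CPU threads (the class default)
    model = SynapseModel("synapse_base.onnx", num_threads=2)
    print(f"Model size: {model.get_size_mb():.1f} MB")

    # Evaluate the standard starting position
    result = model.evaluate(chess.STARTING_FEN)
    print(f"Value: {result['value']:+.3f}")
    if "policy" in result:
        print(f"Policy output shape: {result['policy'].shape}")

The value head is read from the first output and the policy head, when the exported graph provides a second output, comes back as a raw numpy array for the caller to decode.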