| |
|
|
| use rayon::prelude::*; |
|
|
| use crate::board::GameState; |
|
|
| |
/// Per-ply snapshots of a batch of replayed games, stored as flat row-major buffers.
///
/// As produced by `extract_board_states`, every per-ply field holds `batch * max_ply`
/// entries, where the entry for game `b` at ply `t` is at index `b * max_ply + t`.
/// `boards` holds 64 entries per ply instead. Entries at or beyond a game's length are
/// not written and keep their fill values: `0` / `false`, and `-1` for `ep_square`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoardStates {
    /// Piece codes from `GameState::board_array`: `batch * max_ply * 64` entries.
    /// The square at (`rank`, `file`) of ply index `i` is at `i * 64 + rank * 8 + file`.
    pub boards: Vec<i8>,
    /// `true` when `GameState::is_white_to_move` reported white to move.
    pub side_to_move: Vec<bool>,
    /// Bit set from `GameState::castling_rights_bits` (`0b1111` in the starting position).
    pub castling_rights: Vec<u8>,
    /// Value of `GameState::ep_square`. `-1` is the fill value and is also what the
    /// starting position reports.
    pub ep_square: Vec<i8>,
    /// Value of `GameState::is_check` for the side to move.
    pub is_check: Vec<bool>,
    /// Halfmove clock, saturated at `u8::MAX` (255).
    pub halfmove_clock: Vec<u8>,
}
|
|
| |
| |
| pub fn extract_board_states( |
| move_ids: &[i16], |
| game_lengths: &[i16], |
| max_ply: usize, |
| ) -> BoardStates { |
| let batch = game_lengths.len(); |
|
|
| |
| let results: Vec<_> = (0..batch) |
| .into_par_iter() |
| .map(|b| { |
| let length = game_lengths[b] as usize; |
| let mut state = GameState::new(); |
|
|
| let mut boards = vec![0i8; length * 64]; |
| let mut side_to_move = vec![false; length]; |
| let mut castling_rights = vec![0u8; length]; |
| let mut ep_square = vec![-1i8; length]; |
| let mut is_check = vec![false; length]; |
| let mut halfmove_clock = vec![0u8; length]; |
|
|
| for t in 0..length { |
| |
| let board = state.board_array(); |
| for rank in 0..8 { |
| for file in 0..8 { |
| boards[t * 64 + rank * 8 + file] = board[rank][file]; |
| } |
| } |
| side_to_move[t] = state.is_white_to_move(); |
| castling_rights[t] = state.castling_rights_bits(); |
| ep_square[t] = state.ep_square(); |
| is_check[t] = state.is_check(); |
| halfmove_clock[t] = std::cmp::min(state.halfmove_clock(), u8::MAX as u32) as u8; |
|
|
| |
| let token = move_ids[b * max_ply + t] as u16; |
| state.make_move(token).expect("Move should be legal during replay"); |
| } |
|
|
| (boards, side_to_move, castling_rights, ep_square, is_check, halfmove_clock) |
| }) |
| .collect(); |
|
|
| |
| let mut all_boards = vec![0i8; batch * max_ply * 64]; |
| let mut all_stm = vec![false; batch * max_ply]; |
| let mut all_cr = vec![0u8; batch * max_ply]; |
| let mut all_ep = vec![-1i8; batch * max_ply]; |
| let mut all_check = vec![false; batch * max_ply]; |
| let mut all_hmc = vec![0u8; batch * max_ply]; |
|
|
| for (b, (boards, stm, cr, ep, check, hmc)) in results.into_iter().enumerate() { |
| let length = game_lengths[b] as usize; |
| let ply_offset = b * max_ply; |
|
|
| for t in 0..length { |
| let src_offset = t * 64; |
| let dst_offset = (ply_offset + t) * 64; |
| all_boards[dst_offset..dst_offset + 64] |
| .copy_from_slice(&boards[src_offset..src_offset + 64]); |
| all_stm[ply_offset + t] = stm[t]; |
| all_cr[ply_offset + t] = cr[t]; |
| all_ep[ply_offset + t] = ep[t]; |
| all_check[ply_offset + t] = check[t]; |
| all_hmc[ply_offset + t] = hmc[t]; |
| } |
| } |
|
|
| BoardStates { |
| boards: all_boards, |
| side_to_move: all_stm, |
| castling_rights: all_cr, |
| ep_square: all_ep, |
| is_check: all_check, |
| halfmove_clock: all_hmc, |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::batch::generate_random_games;

    /// Ply 0 of a replayed game is the starting position.
    #[test]
    fn test_extract_initial_position() {
        let batch = generate_random_games(1, 256, 42);
        let states = extract_board_states(&batch.move_ids, &batch.game_lengths, 256);

        // Starting position: white to move, all four castling rights, no en-passant
        // square, not in check, and a halfmove clock of zero.
        assert!(states.side_to_move[0]);
        assert_eq!(states.castling_rights[0], 0b1111);
        assert_eq!(states.ep_square[0], -1);
        assert!(!states.is_check[0]);
        assert_eq!(states.halfmove_clock[0], 0);

        // Piece codes at squares 0 and 4 of the first board. Presumably these are a white
        // rook on a1 and a white king on e1 under this crate's encoding.
        // TODO: confirm against `GameState::board_array`.
        assert_eq!(states.boards[0], 4);
        assert_eq!(states.boards[4], 6);
    }

    /// Plies at or beyond each game's length keep the fill values.
    #[test]
    fn test_padding_beyond_game_length() {
        let max_ply = 256;
        let batch = generate_random_games(2, max_ply, 7);
        let states = extract_board_states(&batch.move_ids, &batch.game_lengths, max_ply);

        for (b, &len) in batch.game_lengths.iter().enumerate() {
            for t in (len as usize)..max_ply {
                let i = b * max_ply + t;
                assert!(!states.side_to_move[i]);
                assert_eq!(states.castling_rights[i], 0);
                assert_eq!(states.ep_square[i], -1);
                assert!(!states.is_check[i]);
                assert_eq!(states.halfmove_clock[i], 0);
                assert!(states.boards[i * 64..(i + 1) * 64].iter().all(|&p| p == 0));
            }
        }
    }
}
|
|