| |
| |
| |
| |
|
|
| use rand_chacha::ChaCha8Rng; |
| use rand::SeedableRng; |
|
|
| use crate::board::GameState; |
| use crate::types::Termination; |
|
|
| |
/// Per-game bookkeeping kept alongside each `GameState` in a batch.
#[derive(Debug, Clone, Copy, PartialEq)]
struct GameMeta {
    /// Whether the agent plays the white pieces in this game.
    agent_is_white: bool,
    /// Set once the game has ended for any reason.
    terminated: bool,
    /// Set when the game ended as a forfeit on a rejected move token.
    forfeited: bool,
    /// Final reward from the agent's perspective; 0.0 while the game is running.
    outcome_reward: f32,
    /// Number of agent moves successfully applied.
    agent_plies: u32,
    /// -1 while the game is running; otherwise the termination code recorded for it.
    termination_code: i8,
}

impl GameMeta {
    /// Creates bookkeeping for a fresh, still-running game.
    fn new(agent_is_white: bool) -> Self {
        Self {
            agent_is_white,
            terminated: false,
            forfeited: false,
            outcome_reward: 0.0,
            agent_plies: 0,
            // -1 marks a game that has not ended yet.
            termination_code: -1,
        }
    }
}
|
|
| |
| pub struct BatchRLEnv { |
| games: Vec<GameState>, |
| meta: Vec<GameMeta>, |
| n_games: usize, |
| max_ply: usize, |
| rng: ChaCha8Rng, |
| } |
|
|
| impl BatchRLEnv { |
| pub fn new(n_games: usize, max_ply: usize, seed: u64) -> Self { |
| Self { |
| games: Vec::with_capacity(n_games), |
| meta: Vec::with_capacity(n_games), |
| n_games, |
| max_ply, |
| rng: ChaCha8Rng::seed_from_u64(seed), |
| } |
| } |
|
|
| |
| pub fn reset(&mut self) { |
| self.games.clear(); |
| self.meta.clear(); |
| let n_white = self.n_games / 2; |
| for i in 0..self.n_games { |
| self.games.push(GameState::new()); |
| self.meta.push(GameMeta::new(i < n_white)); |
| } |
| } |
|
|
| pub fn n_games(&self) -> usize { |
| self.n_games |
| } |
|
|
| pub fn all_terminated(&self) -> bool { |
| self.meta.iter().all(|m| m.terminated) |
| } |
|
|
| |
| pub fn active_agent_games(&self) -> Vec<u32> { |
| (0..self.n_games) |
| .filter(|&i| { |
| let m = &self.meta[i]; |
| if m.terminated { |
| return false; |
| } |
| let white_to_move = self.games[i].is_white_to_move(); |
| white_to_move == m.agent_is_white |
| }) |
| .map(|i| i as u32) |
| .collect() |
| } |
|
|
| |
| pub fn active_opponent_games(&self) -> Vec<u32> { |
| (0..self.n_games) |
| .filter(|&i| { |
| let m = &self.meta[i]; |
| if m.terminated { |
| return false; |
| } |
| let white_to_move = self.games[i].is_white_to_move(); |
| white_to_move != m.agent_is_white |
| }) |
| .map(|i| i as u32) |
| .collect() |
| } |
|
|
| |
| |
| |
|
|
| fn finalize(&mut self, gi: usize) { |
| let m = &mut self.meta[gi]; |
| m.terminated = true; |
|
|
| if let Some(term) = self.games[gi].check_termination(self.max_ply) { |
| m.termination_code = term.as_u8() as i8; |
| match term { |
| Termination::Checkmate => { |
| |
| let loser_is_white = self.games[gi].is_white_to_move(); |
| if loser_is_white == m.agent_is_white { |
| m.outcome_reward = -1.0; |
| } else { |
| m.outcome_reward = 1.0; |
| } |
| } |
| |
| _ => { |
| m.outcome_reward = 0.0; |
| } |
| } |
| } |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| |
| pub fn apply_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> (Vec<bool>, Vec<i8>) { |
| let mut flags = Vec::with_capacity(game_indices.len()); |
| let mut term_codes = Vec::with_capacity(game_indices.len()); |
| for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { |
| let gi = gi as usize; |
| if self.meta[gi].terminated { |
| flags.push(false); |
| term_codes.push(self.meta[gi].termination_code); |
| continue; |
| } |
|
|
| if self.games[gi].make_move(token).is_err() { |
| self.meta[gi].terminated = true; |
| self.meta[gi].forfeited = true; |
| self.meta[gi].termination_code = -3; |
| flags.push(false); |
| term_codes.push(-3); |
| continue; |
| } |
|
|
| flags.push(true); |
|
|
| if let Some(term) = self.games[gi].check_termination(self.max_ply) { |
| self.meta[gi].terminated = true; |
| self.meta[gi].termination_code = term.as_u8() as i8; |
| term_codes.push(term.as_u8() as i8); |
| } else { |
| term_codes.push(-1); |
| } |
| } |
| (flags, term_codes) |
| } |
|
|
| |
| |
| |
| pub fn load_prefixes(&mut self, move_ids: &[u16], lengths: &[u32], n_games: usize, max_ply: usize) -> Vec<i8> { |
| let mut term_codes = Vec::with_capacity(n_games); |
| for gi in 0..n_games { |
| let len = lengths[gi] as usize; |
| let mut tc: i8 = -1; |
| for t in 0..len { |
| let token = move_ids[gi * max_ply + t]; |
| if self.games[gi].make_move(token).is_err() { |
| self.meta[gi].terminated = true; |
| self.meta[gi].forfeited = true; |
| self.meta[gi].termination_code = -3; |
| tc = -3; |
| break; |
| } |
| if let Some(term) = self.games[gi].check_termination(self.max_ply) { |
| self.meta[gi].terminated = true; |
| self.meta[gi].termination_code = term.as_u8() as i8; |
| tc = term.as_u8() as i8; |
| break; |
| } |
| } |
| term_codes.push(tc); |
| } |
| term_codes |
| } |
|
|
| |
| |
| |
|
|
| |
| pub fn apply_agent_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> Vec<bool> { |
| let mut flags = Vec::with_capacity(game_indices.len()); |
| for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { |
| let gi = gi as usize; |
| if self.meta[gi].terminated { |
| flags.push(false); |
| continue; |
| } |
|
|
| if self.games[gi].make_move(token).is_err() { |
| self.meta[gi].terminated = true; |
| self.meta[gi].forfeited = true; |
| self.meta[gi].outcome_reward = -1.0; |
| flags.push(false); |
| continue; |
| } |
|
|
| flags.push(true); |
| self.meta[gi].agent_plies += 1; |
|
|
| if self.games[gi].check_termination(self.max_ply).is_some() { |
| self.finalize(gi); |
| } |
| } |
| flags |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| pub fn apply_random_opponent_moves(&mut self) -> Vec<u32> { |
| let opp_games = self.active_opponent_games(); |
| for &gi in &opp_games { |
| let gi = gi as usize; |
| self.games[gi].make_random_move(&mut self.rng); |
| if self.games[gi].check_termination(self.max_ply).is_some() { |
| self.finalize(gi); |
| } |
| } |
| opp_games |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| pub fn apply_opponent_moves(&mut self, game_indices: &[u32], tokens: &[u16]) -> Vec<bool> { |
| let mut flags = Vec::with_capacity(game_indices.len()); |
| for (&gi, &token) in game_indices.iter().zip(tokens.iter()) { |
| let gi = gi as usize; |
| if self.meta[gi].terminated { |
| flags.push(false); |
| continue; |
| } |
|
|
| if self.games[gi].make_move(token).is_err() { |
| |
| self.meta[gi].terminated = true; |
| self.meta[gi].outcome_reward = 0.0; |
| flags.push(false); |
| continue; |
| } |
|
|
| flags.push(true); |
|
|
| if self.games[gi].check_termination(self.max_ply).is_some() { |
| self.finalize(gi); |
| } |
| } |
| flags |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| pub fn get_legal_token_masks_batch( |
| &self, |
| game_indices: &[u32], |
| vocab_size: usize, |
| ) -> Vec<bool> { |
| let b = game_indices.len(); |
| let mut masks = vec![false; b * vocab_size]; |
| for (bi, &gi) in game_indices.iter().enumerate() { |
| for tok in self.games[gi as usize].legal_move_tokens() { |
| let idx = bi * vocab_size + tok as usize; |
| if idx < masks.len() { |
| masks[idx] = true; |
| } |
| } |
| } |
| masks |
| } |
|
|
| |
| |
| |
| pub fn get_legal_moves_batch( |
| &self, |
| game_indices: &[u32], |
| ) -> (Vec<(Vec<u16>, Vec<(u16, Vec<u8>)>)>, Vec<bool>) { |
| let b = game_indices.len(); |
| let mut structured = Vec::with_capacity(b); |
| let mut flat_masks = vec![false; b * 4096]; |
|
|
| for (bi, &gi) in game_indices.iter().enumerate() { |
| let (indices, promos, mask) = self.games[gi as usize].legal_moves_full(); |
| structured.push((indices, promos)); |
| flat_masks[bi * 4096..(bi + 1) * 4096].copy_from_slice(&mask); |
| } |
|
|
| (structured, flat_masks) |
| } |
|
|
| |
| pub fn get_move_histories(&self, game_indices: &[u32]) -> (Vec<i64>, Vec<i32>) { |
| let b = game_indices.len(); |
| let mut flat = vec![0i64; b * self.max_ply]; |
| let mut lengths = Vec::with_capacity(b); |
|
|
| for (bi, &gi) in game_indices.iter().enumerate() { |
| let hist = self.games[gi as usize].move_history(); |
| let len = hist.len().min(self.max_ply); |
| for t in 0..len { |
| flat[bi * self.max_ply + t] = hist[t] as i64; |
| } |
| lengths.push(len as i32); |
| } |
|
|
| (flat, lengths) |
| } |
|
|
| |
| pub fn get_sentinel_tokens(&self, game_indices: &[u32]) -> Vec<u16> { |
| game_indices |
| .iter() |
| .map(|&gi| { |
| let tokens = self.games[gi as usize].legal_move_tokens(); |
| if tokens.is_empty() { 1 } else { tokens[0] } |
| }) |
| .collect() |
| } |
|
|
| |
| pub fn get_fens(&self, game_indices: &[u32]) -> Vec<String> { |
| game_indices |
| .iter() |
| .map(|&gi| self.games[gi as usize].fen()) |
| .collect() |
| } |
|
|
| |
| pub fn get_uci_positions(&self, game_indices: &[u32]) -> Vec<String> { |
| game_indices |
| .iter() |
| .map(|&gi| self.games[gi as usize].uci_position_string()) |
| .collect() |
| } |
|
|
| |
| pub fn get_plies(&self, game_indices: &[u32]) -> Vec<u32> { |
| game_indices |
| .iter() |
| .map(|&gi| self.games[gi as usize].ply() as u32) |
| .collect() |
| } |
|
|
| |
| |
| |
| pub fn get_outcomes(&self) -> (Vec<bool>, Vec<bool>, Vec<f32>, Vec<u32>, Vec<i8>, Vec<bool>) { |
| let mut terminated = Vec::with_capacity(self.n_games); |
| let mut forfeited = Vec::with_capacity(self.n_games); |
| let mut rewards = Vec::with_capacity(self.n_games); |
| let mut plies = Vec::with_capacity(self.n_games); |
| let mut codes = Vec::with_capacity(self.n_games); |
| let mut colors = Vec::with_capacity(self.n_games); |
|
|
| for m in &self.meta { |
| terminated.push(m.terminated); |
| forfeited.push(m.forfeited); |
| rewards.push(m.outcome_reward); |
| plies.push(m.agent_plies); |
| codes.push(m.termination_code); |
| colors.push(m.agent_is_white); |
| } |
|
|
| (terminated, forfeited, rewards, plies, codes, colors) |
| } |
|
|
| |
| pub fn game(&self, gi: usize) -> &GameState { |
| &self.games[gi] |
| } |
|
|
| pub fn meta(&self, gi: usize) -> (bool, bool, bool, f32, u32, i8) { |
| let m = &self.meta[gi]; |
| ( |
| m.agent_is_white, |
| m.terminated, |
| m.forfeited, |
| m.outcome_reward, |
| m.agent_plies, |
| m.termination_code, |
| ) |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment with `n` games (ply limit 256, seed 42) and resets it.
    fn fresh_env(n: usize) -> BatchRLEnv {
        let mut env = BatchRLEnv::new(n, 256, 42);
        env.reset();
        env
    }

    #[test]
    fn test_reset_and_active_games() {
        let env = fresh_env(4);
        assert_eq!(env.n_games(), 4);
        assert!(!env.all_terminated());

        // The agent plays white in games 0 and 1, so it is on move there.
        // In games 2 and 3 it plays black, so the opponent moves first.
        assert_eq!(env.active_agent_games(), vec![0, 1]);
        assert_eq!(env.active_opponent_games(), vec![2, 3]);
    }

    #[test]
    fn test_random_opponent_moves() {
        let mut env = fresh_env(4);

        // Only the games where the opponent holds white have a move to make.
        assert_eq!(env.apply_random_opponent_moves(), vec![2, 3]);

        // After those replies the agent is on move in every game.
        assert_eq!(env.active_agent_games(), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_fen_export() {
        let env = fresh_env(1);
        let fens = env.get_fens(&[0]);
        assert!(fens[0].contains("rnbqkbnr"));
    }

    #[test]
    fn test_legal_moves_batch() {
        let env = fresh_env(2);
        let (structured, masks) = env.get_legal_moves_batch(&[0, 1]);
        assert_eq!(structured.len(), 2);
        assert_eq!(masks.len(), 2 * 4096);
        // The standard starting position has exactly 20 legal moves.
        assert_eq!(structured[0].0.len(), 20);
    }
}
|
|