| |
|
|
| use rand::prelude::*; |
| use rand_chacha::ChaCha8Rng; |
| use rayon::prelude::*; |
|
|
| use crate::board::GameState; |
| use crate::types::Termination; |
|
|
| |
| |
| |
| |
| |
| |
| pub fn derive_game_seeds(base_seed: u64, n: usize) -> Vec<u64> { |
| let mut rng = ChaCha8Rng::seed_from_u64(base_seed); |
| (0..n).map(|_| rng.next_u64()).collect() |
| } |
|
|
| |
| pub struct GameRecord { |
| pub move_ids: Vec<u16>, |
| pub game_length: u16, |
| pub termination: Termination, |
| |
| |
| |
| pub legal_grids: Vec<[u64; 64]>, |
| |
| pub legal_promos: Vec<[[bool; 4]; 44]>, |
| } |
|
|
| |
| |
| |
| pub fn generate_one_game_with_labels(seed: u64, max_ply: usize) -> GameRecord { |
| let mut rng = ChaCha8Rng::seed_from_u64(seed); |
| let mut state = GameState::new(); |
| let mut move_ids = Vec::with_capacity(max_ply); |
| let mut legal_grids = Vec::with_capacity(max_ply); |
| let mut legal_promos = Vec::with_capacity(max_ply); |
|
|
| loop { |
| |
| if let Some(term) = state.check_termination(max_ply) { |
| let game_length = state.ply() as u16; |
| return GameRecord { |
| move_ids, |
| game_length, |
| termination: term, |
| legal_grids, |
| legal_promos, |
| }; |
| } |
|
|
| |
| |
| legal_grids.push(state.legal_move_grid()); |
| legal_promos.push(state.legal_promo_mask()); |
|
|
| |
| let tokens = state.legal_move_tokens(); |
| debug_assert!(!tokens.is_empty(), "No legal moves but termination not detected"); |
|
|
| let chosen = tokens[rng.gen_range(0..tokens.len())]; |
| state.make_move(chosen).unwrap(); |
| move_ids.push(chosen); |
| } |
| } |
|
|
| |
| pub fn generate_one_game(seed: u64, max_ply: usize) -> (Vec<u16>, u16, Termination) { |
| let mut rng = ChaCha8Rng::seed_from_u64(seed); |
| let mut state = GameState::new(); |
| let mut move_ids = Vec::with_capacity(max_ply); |
|
|
| loop { |
| if let Some(term) = state.check_termination(max_ply) { |
| return (move_ids, state.ply() as u16, term); |
| } |
|
|
| let tokens = state.legal_move_tokens(); |
| let chosen = tokens[rng.gen_range(0..tokens.len())]; |
| state.make_move(chosen).unwrap(); |
| move_ids.push(chosen); |
| } |
| } |
|
|
| |
| |
| |
| |
/// Final outcome of a game, as used for outcome-conditioned statistics.
///
/// Discriminants are explicit and stored as `u8`, so `outcome as usize` is a
/// stable index into per-outcome arrays such as `OutcomeDistribution::counts`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Outcome {
    /// Checkmate with White to move (White is mated).
    WhiteCheckmated = 0,
    /// Checkmate with Black to move (Black is mated).
    BlackCheckmated = 1,
    /// The side to move has no legal moves and is not in check.
    Stalemate = 2,
    /// Ended by the seventy-five-move rule.
    SeventyFiveMoveRule = 3,
    /// Ended by fivefold repetition.
    FivefoldRepetition = 4,
    /// Ended because neither side has mating material.
    InsufficientMaterial = 5,
    /// Stopped by the caller-supplied ply cap.
    PlyLimit = 6,
}

/// Number of [`Outcome`] variants, i.e. the length of per-outcome arrays.
///
/// Derived from the last discriminant rather than hard-coded, so it cannot
/// drift from the enum. This assumes the discriminants stay contiguous from 0
/// and `PlyLimit` stays last.
pub const NUM_OUTCOMES: usize = Outcome::PlyLimit as usize + 1;
|
|
| impl Outcome { |
| |
| pub fn from_termination(term: Termination, white_to_move_at_end: bool) -> Self { |
| match term { |
| Termination::Checkmate => { |
| |
| if white_to_move_at_end { |
| Outcome::WhiteCheckmated |
| } else { |
| Outcome::BlackCheckmated |
| } |
| } |
| Termination::Stalemate => Outcome::Stalemate, |
| Termination::SeventyFiveMoveRule => Outcome::SeventyFiveMoveRule, |
| Termination::FivefoldRepetition => Outcome::FivefoldRepetition, |
| Termination::InsufficientMaterial => Outcome::InsufficientMaterial, |
| Termination::PlyLimit => Outcome::PlyLimit, |
| } |
| } |
| } |
|
|
| |
| #[derive(Debug, Clone, Default)] |
| pub struct OutcomeDistribution { |
| pub counts: [u32; NUM_OUTCOMES], |
| pub total: u32, |
| } |
|
|
| |
/// Accuracy ceilings for predicting the move played at one sampled position.
///
/// Each field is filled in by `compute_accuracy_ceiling`.
#[derive(Debug, Clone)]
pub struct PositionCeiling {
    /// Number of legal moves in the position.
    pub n_legal: u32,
    /// Ceiling with no outcome knowledge: `1 / n_legal` (0 if there are no legal moves).
    pub unconditional: f64,
    /// Ceiling when the game's final outcome is known, estimated by rollouts.
    ///
    /// It is the largest per-move rollout estimate of P(actual outcome | move)
    /// divided by the sum of those estimates. It falls back to `unconditional`
    /// when every estimate is zero.
    pub conditional: f64,
    /// Cheaper outcome-aware ceiling without rollouts.
    ///
    /// Moves that end the game immediately with a different outcome are
    /// excluded, and the result is uniform over the rest: `1 / n_remaining`.
    /// It falls back to `unconditional` if nothing remains.
    pub naive_conditional: f64,
    /// The game's actual final outcome, as `Outcome as u8`.
    pub actual_outcome: u8,
    /// Ply index of the position (number of moves already played).
    pub ply: u16,
    /// Total number of plies in the game this position came from.
    pub game_length: u16,
}
|
|
| |
| |
| |
| |
| pub fn rollout_legal_moves( |
| prefix_tokens: &[u16], |
| n_rollouts: usize, |
| max_ply: usize, |
| base_seed: u64, |
| ) -> Vec<(u16, OutcomeDistribution)> { |
| let state = match GameState::from_move_tokens(prefix_tokens) { |
| Ok(s) => s, |
| Err(_) => return Vec::new(), |
| }; |
|
|
| let legal_tokens = state.legal_move_tokens(); |
| if legal_tokens.is_empty() { |
| return Vec::new(); |
| } |
|
|
| let seeds = derive_game_seeds(base_seed, legal_tokens.len() * n_rollouts); |
|
|
| legal_tokens |
| .iter() |
| .enumerate() |
| .map(|(move_idx, &token)| { |
| let mut dist = OutcomeDistribution::default(); |
| for r in 0..n_rollouts { |
| let seed = seeds[move_idx * n_rollouts + r]; |
| let mut rng = ChaCha8Rng::seed_from_u64(seed); |
| let mut s = state.clone(); |
| s.make_move(token).unwrap(); |
| let term = s.play_random_to_end(&mut rng, max_ply); |
| let outcome = Outcome::from_termination(term, s.is_white_to_move()); |
| dist.counts[outcome as usize] += 1; |
| dist.total += 1; |
| } |
| (token, dist) |
| }) |
| .collect() |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn compute_accuracy_ceiling( |
| n_games: usize, |
| max_ply: usize, |
| n_rollouts_per_move: usize, |
| sample_rate: f64, |
| base_seed: u64, |
| ) -> Vec<PositionCeiling> { |
| let game_seeds = derive_game_seeds(base_seed, n_games); |
|
|
| |
| let games: Vec<(Vec<u16>, u16, Termination)> = game_seeds |
| .par_iter() |
| .map(|&seed| generate_one_game(seed, max_ply)) |
| .collect(); |
|
|
| |
| let game_outcomes: Vec<Outcome> = games |
| .iter() |
| .map(|(_, game_length, term)| { |
| |
| |
| let white_to_move_at_end = *game_length % 2 == 0; |
| Outcome::from_termination(*term, white_to_move_at_end) |
| }) |
| .collect(); |
|
|
| |
| let mut rng_sample = ChaCha8Rng::seed_from_u64(base_seed.wrapping_add(999)); |
| |
| let mut work_items: Vec<(usize, usize, u8, u16)> = Vec::new(); |
|
|
| for (game_idx, outcome) in game_outcomes.iter().enumerate() { |
| let gl = games[game_idx].1 as usize; |
| let oi = *outcome as u8; |
| for ply in 0..gl { |
| if sample_rate >= 1.0 || rng_sample.gen::<f64>() < sample_rate { |
| work_items.push((game_idx, ply, oi, games[game_idx].1)); |
| } |
| } |
| } |
|
|
| |
| let rollout_seed_base = base_seed.wrapping_add(1_000_000); |
|
|
| work_items |
| .par_iter() |
| .enumerate() |
| .map(|(work_idx, &(game_idx, ply, actual_outcome, game_length))| { |
| let prefix = &games[game_idx].0[..ply]; |
|
|
| |
| let state = GameState::from_move_tokens(prefix).expect("valid prefix"); |
| let legal_tokens = state.legal_move_tokens(); |
| let n_legal = legal_tokens.len() as u32; |
| let unconditional = if n_legal > 0 { 1.0 / n_legal as f64 } else { 0.0 }; |
|
|
| |
| |
| |
| let mut n_wrong_immediate = 0u32; |
| for &token in &legal_tokens { |
| let mut s = state.clone(); |
| s.make_move(token).unwrap(); |
| if let Some(term) = s.check_termination(max_ply) { |
| let move_outcome = Outcome::from_termination(term, s.is_white_to_move()); |
| if move_outcome as u8 != actual_outcome { |
| n_wrong_immediate += 1; |
| } |
| } |
| } |
| let n_remaining = n_legal - n_wrong_immediate; |
| let naive_conditional = if n_remaining > 0 { |
| 1.0 / n_remaining as f64 |
| } else { |
| unconditional |
| }; |
|
|
| |
| let rollout_seed = rollout_seed_base.wrapping_add(work_idx as u64 * 1000); |
| let move_dists = rollout_legal_moves(prefix, n_rollouts_per_move, max_ply, rollout_seed); |
|
|
| let outcome_idx = actual_outcome as usize; |
| let probs: Vec<f64> = move_dists |
| .iter() |
| .map(|(_, dist)| { |
| if dist.total > 0 { |
| dist.counts[outcome_idx] as f64 / dist.total as f64 |
| } else { |
| 0.0 |
| } |
| }) |
| .collect(); |
|
|
| let sum_probs: f64 = probs.iter().sum(); |
| let conditional = if sum_probs > 0.0 { |
| let max_prob = probs.iter().cloned().fold(0.0f64, f64::max); |
| max_prob / sum_probs |
| } else { |
| unconditional |
| }; |
|
|
| PositionCeiling { |
| n_legal, |
| unconditional, |
| conditional, |
| naive_conditional, |
| actual_outcome, |
| ply: ply as u16, |
| game_length, |
| } |
| }) |
| .collect() |
| } |
|
|
| |
/// A random game that ended in checkmate, with labels for the mating position.
///
/// Both grids describe the position before the final (mating) move. Entry
/// `src` of `checkmate_grid` has bit `dst` set for every legal move `src -> dst`
/// that delivers checkmate. `legal_grid` is `GameState::legal_move_grid()` of
/// that same position.
#[derive(Debug, Clone)]
pub struct CheckmateExample {
    /// All move tokens of the game, including the mating move.
    pub move_ids: Vec<u16>,
    /// Number of plies in the game.
    pub game_length: u16,
    /// Per-source-square bitmask of mating destinations in the pre-mate position.
    pub checkmate_grid: [u64; 64],
    /// Legal-move grid of the pre-mate position.
    pub legal_grid: [u64; 64],
}
|
|
| |
| |
| |
| |
| |
| |
| pub fn generate_checkmate_examples( |
| seed: u64, |
| max_ply: usize, |
| n_target: usize, |
| ) -> (Vec<CheckmateExample>, usize) { |
| let batch_size = 4096usize; |
| let mut collected: Vec<CheckmateExample> = Vec::with_capacity(n_target); |
| let mut total_generated = 0usize; |
| let mut game_seed = seed; |
|
|
| while collected.len() < n_target { |
| let seeds = derive_game_seeds(game_seed, batch_size); |
| |
| let batch: Vec<Option<CheckmateExample>> = seeds |
| .into_par_iter() |
| .map(|s| { |
| let mut rng = ChaCha8Rng::seed_from_u64(s); |
| let mut state = GameState::new(); |
| let mut move_ids = Vec::with_capacity(max_ply); |
|
|
| loop { |
| if let Some(term) = state.check_termination(max_ply) { |
| if term != Termination::Checkmate || move_ids.is_empty() { |
| return None; |
| } |
| let game_length = state.ply() as u16; |
|
|
| |
| let mut replay = GameState::new(); |
| for &tok in &move_ids[..move_ids.len() - 1] { |
| replay.make_move(tok).unwrap(); |
| } |
|
|
| let legal_grid = replay.legal_move_grid(); |
| let legal_tokens = replay.legal_move_tokens(); |
|
|
| |
| let mut checkmate_grid = [0u64; 64]; |
| for &tok in &legal_tokens { |
| let mut test = replay.clone(); |
| test.make_move(tok).unwrap(); |
| if test.check_termination(max_ply + 10) == Some(Termination::Checkmate) { |
| |
| let (src, dst) = crate::vocab::token_to_src_dst(tok); |
| checkmate_grid[src as usize] |= 1u64 << dst; |
| } |
| } |
|
|
| return Some(CheckmateExample { |
| move_ids, |
| game_length, |
| checkmate_grid, |
| legal_grid, |
| }); |
| } |
|
|
| let tokens = state.legal_move_tokens(); |
| let chosen = tokens[rng.gen_range(0..tokens.len())]; |
| state.make_move(chosen).unwrap(); |
| move_ids.push(chosen); |
| } |
| }) |
| .collect(); |
|
|
| game_seed += batch_size as u64; |
| total_generated += batch_size; |
|
|
| for example in batch.into_iter().flatten() { |
| if collected.len() >= n_target { |
| break; |
| } |
| collected.push(example); |
| } |
| } |
|
|
| (collected, total_generated) |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_game() {
        let (moves, length, _term) = generate_one_game(42, 256);
        assert_eq!(moves.len(), length as usize);
        assert!(length > 0);
        assert!(length <= 256);
    }

    #[test]
    fn test_generate_game_with_labels() {
        let record = generate_one_game_with_labels(42, 256);
        assert_eq!(record.move_ids.len(), record.game_length as usize);
        assert_eq!(record.legal_grids.len(), record.game_length as usize);
        assert_eq!(record.legal_promos.len(), record.game_length as usize);
    }

    #[test]
    fn test_labels_match_plain_generation() {
        // Both generators consume the RNG identically, so the same seed must
        // produce the same game.
        let record = generate_one_game_with_labels(42, 256);
        let (moves, length, term) = generate_one_game(42, 256);
        assert_eq!(record.move_ids, moves);
        assert_eq!(record.game_length, length);
        assert_eq!(record.termination, term);
    }

    #[test]
    fn test_deterministic() {
        let (m1, l1, t1) = generate_one_game(123, 256);
        let (m2, l2, t2) = generate_one_game(123, 256);
        assert_eq!(m1, m2);
        assert_eq!(l1, l2);
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_different_seeds() {
        let (m1, _, _) = generate_one_game(1, 256);
        let (m2, _, _) = generate_one_game(2, 256);
        // Distinct seeds should essentially never yield the same game.
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_derive_game_seeds_deterministic_prefix() {
        let long = derive_game_seeds(7, 16);
        let short = derive_game_seeds(7, 8);
        assert_eq!(long.len(), 16);
        assert_eq!(&long[..8], &short[..]);
        assert_eq!(derive_game_seeds(7, 16), long);
        assert!(derive_game_seeds(7, 0).is_empty());
    }

    #[test]
    fn test_outcome_from_termination() {
        assert_eq!(
            Outcome::from_termination(Termination::Checkmate, true),
            Outcome::WhiteCheckmated
        );
        assert_eq!(
            Outcome::from_termination(Termination::Checkmate, false),
            Outcome::BlackCheckmated
        );
        assert_eq!(
            Outcome::from_termination(Termination::PlyLimit, true),
            Outcome::PlyLimit
        );
        assert_eq!(
            Outcome::from_termination(Termination::Stalemate, false),
            Outcome::Stalemate
        );
    }
}
|
|