| |
|
|
| use rayon::prelude::*; |
|
|
| use crate::random::{derive_game_seeds, generate_one_game, generate_one_game_with_labels, generate_checkmate_examples, GameRecord}; |
| use crate::types::Termination; |
| use crate::vocab; |
|
|
| |
/// Flat, zero-padded buffers for a batch of games with per-ply legal-move labels.
///
/// Produced by `generate_training_batch`. All per-ply buffers are laid out
/// row-major as `[game][ply]...`, with `max_ply` plies reserved per game.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrainingBatch {
    /// Move tokens, `batch_size * max_ply` entries; slots at or past a game's
    /// length stay `0`.
    pub move_ids: Vec<i16>,
    /// Number of plies actually played in each game (`batch_size` entries).
    pub game_lengths: Vec<i16>,
    /// Legal-move grid, 64 `u64` words per ply
    /// (`batch_size * max_ply * 64` entries); zero past the end of a game.
    pub legal_move_grid: Vec<u64>,
    /// Promotion legality flags, `44 * 4` booleans per ply
    /// (`batch_size * max_ply * 44 * 4` entries); `false` past the end of a game.
    pub legal_promo_mask: Vec<bool>,
    /// Per-game termination code as returned by `Termination::as_u8`.
    pub termination_codes: Vec<u8>,
    /// Number of games in the batch.
    pub batch_size: usize,
    /// Number of ply slots reserved per game.
    pub max_ply: usize,
}
|
|
| |
/// Flat, zero-padded move buffers for a batch of games without legality labels.
///
/// Returned by `generate_random_games`, `generate_completed_games` and
/// `generate_checkmate_games`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GameBatch {
    /// Move tokens, `n_games * max_ply` entries laid out row-major per game;
    /// slots at or past a game's length stay `0`.
    pub move_ids: Vec<i16>,
    /// Number of plies actually played in each game (`n_games` entries).
    pub game_lengths: Vec<i16>,
    /// Per-game termination code as returned by `Termination::as_u8`.
    pub termination_codes: Vec<u8>,
    /// Number of games in the batch.
    pub n_games: usize,
    /// Number of ply slots reserved per game.
    pub max_ply: usize,
}
|
|
| |
| |
| pub fn generate_training_batch(batch_size: usize, max_ply: usize, seed: u64) -> TrainingBatch { |
| |
| let seeds = derive_game_seeds(seed, batch_size); |
| let records: Vec<GameRecord> = seeds |
| .into_par_iter() |
| .map(|s| generate_one_game_with_labels(s, max_ply)) |
| .collect(); |
|
|
| |
| let total_ply = batch_size * max_ply; |
| let mut move_ids = vec![0i16; total_ply]; |
| let mut game_lengths = Vec::with_capacity(batch_size); |
| let mut legal_move_grid = vec![0u64; total_ply * 64]; |
| let mut legal_promo_mask = vec![false; total_ply * 44 * 4]; |
| let mut termination_codes = Vec::with_capacity(batch_size); |
|
|
| for (b, record) in records.iter().enumerate() { |
| let length = record.game_length as usize; |
| game_lengths.push(record.game_length as i16); |
| termination_codes.push(record.termination.as_u8()); |
|
|
| |
| for t in 0..length { |
| move_ids[b * max_ply + t] = record.move_ids[t] as i16; |
| } |
|
|
| |
| for t in 0..length { |
| let grid_offset = (b * max_ply + t) * 64; |
| debug_assert_eq!(record.legal_grids[t].len(), 64); |
| legal_move_grid[grid_offset..grid_offset + 64] |
| .copy_from_slice(&record.legal_grids[t]); |
| } |
|
|
| |
| for t in 0..length { |
| let promo_offset = (b * max_ply + t) * 44 * 4; |
| |
| let flat: &[bool; 176] = unsafe { |
| &*(&record.legal_promos[t] as *const [[bool; 4]; 44] as *const [bool; 176]) |
| }; |
| legal_promo_mask[promo_offset..promo_offset + 176].copy_from_slice(flat); |
| } |
| } |
|
|
| TrainingBatch { |
| move_ids, |
| game_lengths, |
| legal_move_grid, |
| legal_promo_mask, |
| termination_codes, |
| batch_size, |
| max_ply, |
| } |
| } |
|
|
| |
| pub fn generate_random_games(n_games: usize, max_ply: usize, seed: u64) -> GameBatch { |
| let seeds = derive_game_seeds(seed, n_games); |
| let results: Vec<(Vec<u16>, u16, Termination)> = seeds |
| .into_par_iter() |
| .map(|s| generate_one_game(s, max_ply)) |
| .collect(); |
|
|
| let mut move_ids = vec![0i16; n_games * max_ply]; |
| let mut game_lengths = Vec::with_capacity(n_games); |
| let mut termination_codes = Vec::with_capacity(n_games); |
|
|
| for (b, (moves, length, term)) in results.iter().enumerate() { |
| game_lengths.push(*length as i16); |
| termination_codes.push(term.as_u8()); |
|
|
| for t in 0..(*length as usize) { |
| move_ids[b * max_ply + t] = moves[t] as i16; |
| } |
| } |
|
|
| GameBatch { |
| move_ids, |
| game_lengths, |
| termination_codes, |
| n_games, |
| max_ply, |
| } |
| } |
|
|
| |
/// Flat buffers for a batch of checkmate examples.
///
/// Produced by `generate_checkmate_training_batch`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckmateTrainingBatch {
    /// Move tokens, `n_games * max_ply` entries laid out row-major per example;
    /// slots at or past an example's length stay `0`.
    pub move_ids: Vec<i16>,
    /// Game length of each example (`n_games` entries).
    pub game_lengths: Vec<i16>,
    /// Checkmate grid of each example, 64 `u64` words per example.
    pub checkmate_targets: Vec<u64>,
    /// Legal-move grid of each example, 64 `u64` words per example.
    pub legal_grids: Vec<u64>,
    /// Number of examples actually stored in this batch.
    pub n_games: usize,
    /// Number of ply slots reserved per example.
    pub max_ply: usize,
    /// Count reported by `generate_checkmate_examples` alongside the examples.
    /// NOTE(review): presumably the total number of games generated to find
    /// them; confirm against that function.
    pub total_generated: usize,
}
|
|
| |
| pub fn generate_checkmate_training_batch( |
| n_games: usize, |
| max_ply: usize, |
| seed: u64, |
| ) -> CheckmateTrainingBatch { |
| let (examples, total_generated) = generate_checkmate_examples(seed, max_ply, n_games); |
| let n = examples.len(); |
|
|
| let mut move_ids = vec![0i16; n * max_ply]; |
| let mut game_lengths = Vec::with_capacity(n); |
| let mut checkmate_targets = vec![0u64; n * 64]; |
| let mut legal_grids = vec![0u64; n * 64]; |
|
|
| for (b, ex) in examples.iter().enumerate() { |
| game_lengths.push(ex.game_length as i16); |
| for t in 0..(ex.game_length as usize).min(max_ply) { |
| move_ids[b * max_ply + t] = ex.move_ids[t] as i16; |
| } |
| checkmate_targets[b * 64..(b + 1) * 64].copy_from_slice(&ex.checkmate_grid); |
| legal_grids[b * 64..(b + 1) * 64].copy_from_slice(&ex.legal_grid); |
| } |
|
|
| CheckmateTrainingBatch { |
| move_ids, |
| game_lengths, |
| checkmate_targets, |
| legal_grids, |
| n_games: n, |
| max_ply, |
| total_generated, |
| } |
| } |
|
|
| |
| |
| pub fn generate_completed_games(n_games: usize, max_ply: usize, seed: u64) -> GameBatch { |
| let batch_size = 4096.max(n_games * 2); |
| let mut collected: Vec<(Vec<u16>, u16, Termination)> = Vec::with_capacity(n_games); |
| let mut game_seed = seed; |
|
|
| while collected.len() < n_games { |
| let seeds = derive_game_seeds(game_seed, batch_size); |
| let results: Vec<(Vec<u16>, u16, Termination)> = seeds |
| .into_par_iter() |
| .map(|s| generate_one_game(s, max_ply)) |
| .collect(); |
|
|
| game_seed += batch_size as u64; |
|
|
| for result in results { |
| if result.2 != Termination::PlyLimit { |
| collected.push(result); |
| if collected.len() >= n_games { |
| break; |
| } |
| } |
| } |
| } |
|
|
| let mut move_ids = vec![0i16; n_games * max_ply]; |
| let mut game_lengths = Vec::with_capacity(n_games); |
| let mut termination_codes = Vec::with_capacity(n_games); |
|
|
| for (b, (moves, length, term)) in collected.iter().enumerate() { |
| game_lengths.push(*length as i16); |
| termination_codes.push(term.as_u8()); |
| for t in 0..(*length as usize) { |
| move_ids[b * max_ply + t] = moves[t] as i16; |
| } |
| } |
|
|
| GameBatch { |
| move_ids, |
| game_lengths, |
| termination_codes, |
| n_games, |
| max_ply, |
| } |
| } |
|
|
| |
| |
| pub fn generate_checkmate_games( |
| n_white_wins: usize, |
| n_black_wins: usize, |
| max_ply: usize, |
| seed: u64, |
| ) -> (GameBatch, usize) { |
| use std::sync::atomic::{AtomicUsize, Ordering}; |
|
|
| let batch_size = 4096; |
| let target_total = n_white_wins + n_black_wins; |
|
|
| let mut collected_white: Vec<(Vec<u16>, u16)> = Vec::with_capacity(n_white_wins); |
| let mut collected_black: Vec<(Vec<u16>, u16)> = Vec::with_capacity(n_black_wins); |
| let mut total_generated: usize = 0; |
| let mut game_seed = seed; |
|
|
| while collected_white.len() < n_white_wins || collected_black.len() < n_black_wins { |
| |
| let batch_seeds = derive_game_seeds(game_seed, batch_size); |
| let results: Vec<(Vec<u16>, u16, Termination)> = batch_seeds |
| .into_par_iter() |
| .map(|s| generate_one_game(s, max_ply)) |
| .collect(); |
|
|
| game_seed += batch_size as u64; |
| total_generated += batch_size; |
|
|
| for (moves, length, term) in results { |
| if term != Termination::Checkmate { |
| continue; |
| } |
| |
| if length % 2 == 1 { |
| if collected_white.len() < n_white_wins { |
| collected_white.push((moves, length)); |
| } |
| } else { |
| if collected_black.len() < n_black_wins { |
| collected_black.push((moves, length)); |
| } |
| } |
| if collected_white.len() >= n_white_wins && collected_black.len() >= n_black_wins { |
| break; |
| } |
| } |
| } |
|
|
| |
| let n_games = collected_white.len() + collected_black.len(); |
| let mut move_ids = vec![0i16; n_games * max_ply]; |
| let mut game_lengths = Vec::with_capacity(n_games); |
| let mut termination_codes = Vec::with_capacity(n_games); |
|
|
| for (b, (moves, length)) in collected_white.iter().chain(collected_black.iter()).enumerate() { |
| game_lengths.push(*length as i16); |
| termination_codes.push(Termination::Checkmate.as_u8()); |
| for t in 0..(*length as usize) { |
| move_ids[b * max_ply + t] = moves[t] as i16; |
| } |
| } |
|
|
| (GameBatch { |
| move_ids, |
| game_lengths, |
| termination_codes, |
| n_games, |
| max_ply, |
| }, total_generated) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// Causal-language-model view of a game batch, plus the raw game data it was built from.
///
/// Produced by `generate_clm_batch`. Sequence buffers hold `batch_size * seq_len`
/// entries laid out row-major per game, with `seq_len == max_ply + 1`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CLMBatch {
    /// Per game: outcome token at position 0, then the game's moves, then `0`
    /// for every remaining position.
    pub input_ids: Vec<i16>,
    /// `input_ids` shifted left by one position within each row; the last
    /// position of every row is `0`.
    pub targets: Vec<i16>,
    /// `true` for positions `0..=game_length` of each row, `false` elsewhere.
    pub loss_mask: Vec<bool>,
    /// Raw move tokens of the underlying game batch (`batch_size * max_ply`).
    pub move_ids: Vec<i16>,
    /// Number of plies played in each game (`batch_size` entries).
    pub game_lengths: Vec<i16>,
    /// Per-game termination code as returned by `Termination::as_u8`.
    pub termination_codes: Vec<u8>,
    /// Number of games in the batch.
    pub batch_size: usize,
    /// Sequence length of each row in `input_ids`, `targets` and `loss_mask`.
    pub seq_len: usize,
    /// Ply slots per game in `move_ids` (`seq_len - 1`).
    pub max_ply: usize,
}
|
|
| |
| |
| |
| |
| pub fn generate_clm_batch( |
| batch_size: usize, |
| seq_len: usize, |
| seed: u64, |
| discard_ply_limit: bool, |
| ) -> CLMBatch { |
| let max_ply = seq_len - 1; |
|
|
| let game_batch = if discard_ply_limit { |
| generate_completed_games(batch_size, max_ply, seed) |
| } else { |
| generate_random_games(batch_size, max_ply, seed) |
| }; |
|
|
| let mut input_ids = vec![0i16; batch_size * seq_len]; |
| let mut targets = vec![0i16; batch_size * seq_len]; |
| let mut loss_mask = vec![false; batch_size * seq_len]; |
|
|
| for b in 0..batch_size { |
| let gl = game_batch.game_lengths[b] as usize; |
| let term = match game_batch.termination_codes[b] { |
| 0 => Termination::Checkmate, |
| 1 => Termination::Stalemate, |
| 2 => Termination::SeventyFiveMoveRule, |
| 3 => Termination::FivefoldRepetition, |
| 4 => Termination::InsufficientMaterial, |
| _ => Termination::PlyLimit, |
| }; |
| let outcome = vocab::termination_to_outcome(term, game_batch.game_lengths[b] as u16); |
|
|
| let row = b * seq_len; |
|
|
| |
| input_ids[row] = outcome as i16; |
|
|
| |
| for t in 0..gl { |
| input_ids[row + 1 + t] = game_batch.move_ids[b * max_ply + t]; |
| } |
| |
|
|
| |
| for t in 0..(seq_len - 1) { |
| targets[row + t] = input_ids[row + t + 1]; |
| } |
| |
|
|
| |
| for t in 0..=gl { |
| loss_mask[row + t] = true; |
| } |
| } |
|
|
| CLMBatch { |
| input_ids, |
| targets, |
| loss_mask, |
| move_ids: game_batch.move_ids, |
| game_lengths: game_batch.game_lengths, |
| termination_codes: game_batch.termination_codes, |
| batch_size, |
| seq_len, |
| max_ply, |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Ply slots per game used by the training-batch tests.
    const PLY: usize = 256;

    /// Maps a stored termination code back to its `Termination`, treating any
    /// unknown code as a ply-limit ending.
    fn decode_termination(code: u8) -> Termination {
        match code {
            0 => Termination::Checkmate,
            1 => Termination::Stalemate,
            2 => Termination::SeventyFiveMoveRule,
            3 => Termination::FivefoldRepetition,
            4 => Termination::InsufficientMaterial,
            _ => Termination::PlyLimit,
        }
    }

    /// Buffer sizes match the requested batch shape and every game has a sane length.
    #[test]
    fn test_training_batch() {
        let games = 4;
        let batch = generate_training_batch(games, PLY, 42);
        assert_eq!(batch.move_ids.len(), games * PLY);
        assert_eq!(batch.game_lengths.len(), games);
        assert_eq!(batch.legal_move_grid.len(), games * PLY * 64);
        assert_eq!(batch.legal_promo_mask.len(), games * PLY * 44 * 4);
        assert_eq!(batch.termination_codes.len(), games);

        for len in batch.game_lengths.iter().copied() {
            assert!((1..=256).contains(&len));
        }
    }

    /// Unlabelled batches have the requested shape.
    #[test]
    fn test_random_games() {
        let games = 8;
        let batch = generate_random_games(games, PLY, 42);
        assert_eq!(batch.move_ids.len(), games * PLY);
        assert_eq!(batch.game_lengths.len(), games);
    }

    /// Every slot at or past the game length holds the PAD token.
    #[test]
    fn test_pad_after_game_end() {
        let batch = generate_training_batch(2, PLY, 42);
        for (b, &raw_len) in batch.game_lengths.iter().enumerate().take(2) {
            let len = raw_len as usize;
            let row = &batch.move_ids[b * PLY..(b + 1) * PLY];

            if len < PLY {
                assert_eq!(
                    row[len],
                    vocab::PAD_TOKEN as i16,
                    "Position game_length should be PAD (0)"
                );
            }

            for (t, &tok) in row.iter().enumerate().skip(len) {
                assert_eq!(
                    tok, 0,
                    "Position {} (after game_length={}) should be PAD",
                    t, len
                );
            }
        }
    }

    /// Identical arguments produce identical training batches.
    #[test]
    fn test_batch_deterministic() {
        let first = generate_training_batch(4, PLY, 99);
        let second = generate_training_batch(4, PLY, 99);
        assert_eq!(first.move_ids, second.move_ids);
        assert_eq!(first.game_lengths, second.game_lengths);
        assert_eq!(first.legal_move_grid, second.legal_move_grid);
    }

    /// Checks row layout, shifted targets and loss mask of a CLM batch.
    #[test]
    fn test_clm_batch_format() {
        let seq_len = 256;
        let games = 8;
        let batch = generate_clm_batch(games, seq_len, 42, false);
        assert_eq!(batch.input_ids.len(), games * seq_len);
        assert_eq!(batch.targets.len(), games * seq_len);
        assert_eq!(batch.loss_mask.len(), games * seq_len);
        assert_eq!(batch.move_ids.len(), games * (seq_len - 1));
        assert_eq!(batch.game_lengths.len(), games);

        let outcome_lo = vocab::OUTCOME_BASE as i16;
        let outcome_hi = vocab::PLY_LIMIT as i16;

        for b in 0..games {
            let gl = batch.game_lengths[b] as usize;
            let row = b * seq_len;

            // Position 0: outcome token.
            let outcome = batch.input_ids[row];
            assert!(
                (outcome_lo..=outcome_hi).contains(&outcome),
                "Position 0 should be outcome token, got {}",
                outcome
            );

            // Positions 1..=gl: move tokens.
            for t in 1..=gl {
                let tok = batch.input_ids[row + t];
                assert!(
                    (1..=4272).contains(&tok),
                    "Position {} should be move token, got {}",
                    t,
                    tok
                );
            }

            // Everything after the moves is padding.
            for t in (gl + 1)..seq_len {
                assert_eq!(
                    batch.input_ids[row + t],
                    0,
                    "Position {} should be PAD, got {}",
                    t,
                    batch.input_ids[row + t]
                );
            }

            // Targets are the inputs shifted left by one.
            for t in 0..(seq_len - 1) {
                assert_eq!(
                    batch.targets[row + t],
                    batch.input_ids[row + t + 1],
                    "targets[{}] should equal input_ids[{}]",
                    t,
                    t + 1
                );
            }
            assert_eq!(
                batch.targets[row + seq_len - 1],
                0,
                "Last target should be PAD"
            );

            assert_eq!(
                batch.targets[row + gl],
                0,
                "Target at game_length should be PAD"
            );

            // Loss mask is on for 0..=gl and off afterwards.
            for t in 0..=gl {
                assert!(
                    batch.loss_mask[row + t],
                    "loss_mask[{}] should be true (gl={})",
                    t,
                    gl
                );
            }
            for t in (gl + 1)..seq_len {
                assert!(
                    !batch.loss_mask[row + t],
                    "loss_mask[{}] should be false (gl={})",
                    t,
                    gl
                );
            }
        }
    }

    /// Identical arguments produce identical CLM batches.
    #[test]
    fn test_clm_batch_deterministic() {
        let first = generate_clm_batch(4, 256, 99, false);
        let second = generate_clm_batch(4, 256, 99, false);
        assert_eq!(first.input_ids, second.input_ids);
        assert_eq!(first.targets, second.targets);
        assert_eq!(first.loss_mask, second.loss_mask);
        assert_eq!(first.game_lengths, second.game_lengths);
    }

    /// The outcome token matches the one derived from termination code and length.
    #[test]
    fn test_clm_batch_outcome_correctness() {
        let batch = generate_clm_batch(32, 256, 42, false);
        for b in 0..32 {
            let gl = batch.game_lengths[b] as usize;
            let tc = batch.termination_codes[b];
            let expected = vocab::termination_to_outcome(decode_termination(tc), gl as u16);
            assert_eq!(
                batch.input_ids[b * 256] as u16,
                expected,
                "Game {} outcome mismatch: tc={}, gl={}",
                b,
                tc,
                gl
            );
        }
    }
}
|
|