| |
| |
| |
|
|
| from __future__ import annotations |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Tuple, Union |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from copy import deepcopy |
|
|
|
|
@dataclass
class NTMConfig:
    """Hyper-parameters describing one ``NeuralTuringMachine`` architecture.

    ``input_dim`` and ``output_dim`` are required; every other field has a
    default.
    """

    # Size of the last axis of the input tensors fed to the machine.
    input_dim: int
    # Size of the last axis of the tensors the machine emits.
    output_dim: int
    # Hidden size of the LSTM controller cell.
    controller_dim: int = 128
    # NOTE(review): NeuralTuringMachine in this file always builds a single
    # nn.LSTMCell and never reads this field -- confirm whether it is used
    # elsewhere.
    controller_layers: int = 1
    # Number of rows (addressable slots) in the external memory.
    memory_slots: int = 128
    # Width of each memory row; also the width of keys/erase/add vectors.
    memory_dim: int = 32
    # Number of read heads.
    heads_read: int = 1
    # Number of write heads.
    heads_write: int = 1
    # Std-dev of the Gaussian used to fill the initial memory; the memory is
    # left at zero when this is not strictly positive.
    init_std: float = 0.1
|
|
| |
| |
|
|
class NeuralTuringMachine(nn.Module):
    """Memory-augmented recurrent network: an LSTM cell driving read/write heads.

    At every step the controller receives the current input concatenated with
    the read vectors of the previous step and emits an interface vector.  That
    vector is split into per-head keys and strengths (read and write heads)
    plus erase and add vectors (write heads).  Heads address the external
    memory through a softmax over strength-scaled cosine similarity.

    Only this content-based addressing is implemented, and
    ``cfg.controller_layers`` is not consulted: the controller is always one
    ``nn.LSTMCell``.

    The recurrent state is a dict with the keys

    * ``'M'``   -- memory, ``(B, memory_slots, memory_dim)``
    * ``'w_r'`` -- read weights, ``(B, heads_read, memory_slots)``
    * ``'w_w'`` -- write weights, ``(B, heads_write, memory_slots)``
    * ``'r'``   -- read vectors, ``(B, heads_read, memory_dim)``
    * ``'h'``, ``'c'`` -- LSTM hidden / cell state, ``(B, controller_dim)``
    """

    # Lower bound applied to vector norms before dividing (avoids 0/0).
    _NORM_EPS = 1e-8
    # Scale of the previous-step weighting that is added to the logits.
    _PREV_WEIGHT_SCALE = 0.02

    def __init__(self, cfg: "NTMConfig"):
        super().__init__()
        self.cfg = cfg
        R, W, Dm = cfg.heads_read, cfg.heads_write, cfg.memory_dim

        # Controller input = external input + all read vectors of the last step.
        ctrl_in = cfg.input_dim + R * Dm
        self.controller = nn.LSTMCell(ctrl_in, cfg.controller_dim)

        # Per read head: key (Dm) + strength (1).
        iface_read = R * (Dm + 1)
        # Per write head: key (Dm) + strength (1) + erase (Dm) + add (Dm).
        iface_write = W * (Dm + 1 + Dm + Dm)
        self.interface = nn.Linear(cfg.controller_dim, iface_read + iface_write)
        self.output = nn.Linear(cfg.controller_dim + R * Dm, cfg.output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise all Linear and LSTMCell weights in place.

        Linear layers get Xavier-uniform weights and zero biases.  The LSTM
        cell gets Xavier-uniform input weights, orthogonal recurrent weights
        and zero biases, except the second quarter of each bias vector (the
        forget gate in PyTorch's ``i, f, g, o`` layout), which is set to 1.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LSTMCell):
                nn.init.xavier_uniform_(m.weight_ih)
                nn.init.orthogonal_(m.weight_hh)
                nn.init.zeros_(m.bias_ih)
                nn.init.zeros_(m.bias_hh)
                hs = m.bias_ih.shape[0] // 4
                # In-place edit of leaf parameters: done under no_grad rather
                # than through the legacy ``.data`` attribute.
                with torch.no_grad():
                    m.bias_ih[hs:2 * hs].fill_(1.0)
                    m.bias_hh[hs:2 * hs].fill_(1.0)

    def initial_state(self, batch_size: int, device=None):
        """Build a fresh recurrent state for ``batch_size`` sequences.

        Args:
            batch_size: number of sequences processed in parallel.
            device: where to allocate the state; defaults to the device of
                the module's first parameter.

        Returns:
            The state dict described in the class docstring.  Memory is
            Gaussian-initialised with std ``cfg.init_std`` (zeros when that is
            not positive), head weights are uniform over the slots, and the
            read vectors and LSTM state are zero.  Tensors use the dtype of
            the module's parameters so ``.double()`` / ``.half()`` models get
            a matching state.
        """
        cfg = self.cfg
        ref = next(self.parameters())
        device = device or ref.device
        opts = {'device': device, 'dtype': ref.dtype}

        M = torch.zeros(batch_size, cfg.memory_slots, cfg.memory_dim, **opts)
        if cfg.init_std > 0:
            M.normal_(0.0, cfg.init_std)

        w_r = torch.ones(batch_size, cfg.heads_read, cfg.memory_slots, **opts) / cfg.memory_slots
        w_w = torch.ones(batch_size, cfg.heads_write, cfg.memory_slots, **opts) / cfg.memory_slots
        r = torch.zeros(batch_size, cfg.heads_read, cfg.memory_dim, **opts)
        h = torch.zeros(batch_size, cfg.controller_dim, **opts)
        c = torch.zeros(batch_size, cfg.controller_dim, **opts)

        return {'M': M, 'w_r': w_r, 'w_w': w_w, 'r': r, 'h': h, 'c': c}

    @staticmethod
    def _address(M, k, beta, prev_weight=None):
        """Return softmax addressing weights of shape ``(B, heads, slots)``.

        Args:
            M: memory, ``(B, slots, Dm)``.
            k: one key per head, ``(B, heads, Dm)``.
            beta: one strength per head, ``(B, heads)``.
            prev_weight: optional previous weights ``(B, heads, slots)``; a
                small multiple of them is added to the logits.
        """
        eps = NeuralTuringMachine._NORM_EPS
        M_norm = torch.norm(M, dim=-1, keepdim=True).clamp_min(eps)
        k_norm = torch.norm(k, dim=-1, keepdim=True).clamp_min(eps)
        # (B, 1, slots, Dm) * (B, heads, 1, Dm) summed over Dm -> (B, heads, slots)
        cos_sim = torch.sum(M.unsqueeze(1) * k.unsqueeze(2), dim=-1) / (
            M_norm.squeeze(-1).unsqueeze(1) * k_norm.squeeze(-1).unsqueeze(-1)
        )
        content_logits = beta.unsqueeze(-1) * cos_sim
        if prev_weight is not None:
            content_logits = content_logits + NeuralTuringMachine._PREV_WEIGHT_SCALE * prev_weight
        return F.softmax(content_logits, dim=-1)

    def step(self, x: torch.Tensor, state: Dict[str, torch.Tensor]):
        """Advance the machine by one time step.

        Args:
            x: input of shape ``(B, input_dim)``.
            state: state dict as produced by ``initial_state`` or a previous
                call.  ``'w_r'`` / ``'w_w'`` may be absent, in which case no
                previous-weight term is added to the addressing logits.

        Returns:
            ``(y, new_state)`` with ``y`` of shape ``(B, output_dim)``.  The
            read vectors in ``new_state`` are taken from the memory *before*
            this step's write is applied.
        """
        cfg = self.cfg
        B = x.shape[0]
        R, W, Dm = cfg.heads_read, cfg.heads_write, cfg.memory_dim

        ctrl_in = torch.cat([x, state['r'].reshape(B, -1)], dim=-1)
        h, c = self.controller(ctrl_in, (state['h'], state['c']))
        iface = self.interface(h)

        # Slice the flat interface vector in the order it was sized in
        # __init__: read keys, read strengths, write keys, write strengths,
        # erase vectors, add vectors.
        offset = 0
        k_r = iface[:, offset:offset + R * Dm].reshape(B, R, Dm)
        offset += R * Dm
        beta_r = F.softplus(iface[:, offset:offset + R])
        offset += R

        k_w = iface[:, offset:offset + W * Dm].reshape(B, W, Dm)
        offset += W * Dm
        beta_w = F.softplus(iface[:, offset:offset + W])
        offset += W
        erase = torch.sigmoid(iface[:, offset:offset + W * Dm]).reshape(B, W, Dm)
        offset += W * Dm
        add = torch.tanh(iface[:, offset:offset + W * Dm]).reshape(B, W, Dm)

        M = state['M']
        w_r = self._address(M, k_r, beta_r, prev_weight=state.get('w_r'))
        w_w = self._address(M, k_w, beta_w, prev_weight=state.get('w_w'))
        # Weighted sum over slots: (B, R, slots, 1) * (B, 1, slots, Dm) -> (B, R, Dm)
        r = torch.sum(w_r.unsqueeze(-1) * M.unsqueeze(1), dim=2)

        if W > 0:
            # Erase factors of all write heads are multiplied together, the
            # add contributions are summed.
            erase_term = torch.prod(1 - w_w.unsqueeze(-1) * erase.unsqueeze(2), dim=1)
            M = M * erase_term
            add_term = torch.sum(w_w.unsqueeze(-1) * add.unsqueeze(2), dim=1)
            M = M + add_term

        y = self.output(torch.cat([h, r.reshape(B, -1)], dim=-1))

        new_state = {'M': M, 'w_r': w_r, 'w_w': w_w, 'r': r, 'h': h, 'c': c}
        return y, new_state

    def forward(self, x: torch.Tensor, state=None):
        """Run one step (2-D input) or a whole sequence (3-D input).

        Args:
            x: ``(B, input_dim)`` for a single step, or ``(B, T, input_dim)``
                for a sequence.
            state: optional state dict; a fresh one is created when omitted.

        Returns:
            ``(y, state)`` where ``y`` is ``(B, output_dim)`` for a single
            step and ``(B, T, output_dim)`` for a sequence.  A zero-length
            sequence yields an empty ``(B, 0, output_dim)`` output and the
            unchanged state instead of failing inside ``torch.stack``.
        """
        if x.dim() == 2:
            if state is None:
                state = self.initial_state(x.shape[0], x.device)
            return self.step(x, state)

        B, T, _ = x.shape
        if state is None:
            state = self.initial_state(B, x.device)

        if T == 0:
            empty = torch.zeros(
                B, 0, self.cfg.output_dim,
                device=x.device, dtype=self.output.weight.dtype,
            )
            return empty, state

        outputs = []
        for t in range(T):
            y, state = self.step(x[:, t], state)
            outputs.append(y)

        return torch.stack(outputs, dim=1), state
|
|
@dataclass
class EvolutionaryTuringConfig:
    """Settings for the evolutionary search over NTM architectures and weights."""

    # Number of candidate machines kept in each generation.
    population_size: int = 100
    # Per-element probability that a weight receives Gaussian noise.
    mutation_rate: float = 0.1
    # Per-field probability that an architecture hyper-parameter is perturbed.
    architecture_mutation_rate: float = 0.05
    # Fraction of ``population_size`` carried over unchanged as elites.
    elite_ratio: float = 0.2
    # Number of generations executed by ``run_evolution``.
    max_generations: int = 200
    # Input / output widths shared by every candidate machine.
    input_dim: int = 8
    output_dim: int = 8
    # Device string used for candidates and for the evaluation data.
    device: str = 'cpu'
    # Seed handed to ``torch.manual_seed``; ``None`` leaves the RNG untouched.
    seed: Optional[int] = None
|
|
| |
| |
|
|
class FitnessEvaluator:
    """Scores a machine on two synthetic memory tasks plus a size penalty.

    Every task score is ``1 / (1 + mse)``, so it lies in ``(0, 1]``.  A
    candidate whose forward pass raises, or whose loss is NaN/inf, scores
    ``0.0``.  All evaluation runs under ``torch.no_grad()``.
    """

    def __init__(self, device: str = 'cpu'):
        # Device on which the random task data is generated.
        self.device = device

    @staticmethod
    def _score_from_loss(loss: torch.Tensor) -> float:
        """Map a scalar loss to ``1 / (1 + loss)``; non-finite losses give 0.0."""
        # A NaN score would poison sorting and max() in the evolution loop.
        if not torch.isfinite(loss):
            return 0.0
        return 1.0 / (1.0 + loss.item())

    def copy_task(self, ntm: "NeuralTuringMachine", seq_len: int = 8, batch_size: int = 16) -> float:
        """Score how well ``ntm`` reproduces a random bit sequence.

        The input is ``seq_len`` random binary vectors, one delimiter step
        (last channel set to 1) and then ``seq_len`` all-zero steps.  The
        outputs of those trailing zero steps are compared with the original
        sequence, so the machine has seen the whole sequence and the
        delimiter before it is scored.

        Args:
            ntm: model called as ``ntm(input_seq)`` returning ``(output, state)``.
            seq_len: number of vectors to copy.
            batch_size: number of independent sequences.

        Returns:
            ``1 / (1 + mse)``, or ``0.0`` for a non-positive ``seq_len`` /
            ``batch_size``, a failing forward pass or a non-finite loss.
        """
        if seq_len < 1 or batch_size < 1:
            return 0.0

        with torch.no_grad():
            in_dim = ntm.cfg.input_dim
            x = torch.randint(0, 2, (batch_size, seq_len, in_dim),
                              device=self.device, dtype=torch.float32)

            # NOTE(review): the delimiter flag shares the last channel with
            # the data bits, so it is not uniquely identifiable -- confirm
            # whether a dedicated channel was intended.
            delimiter = torch.zeros(batch_size, 1, in_dim, device=self.device)
            delimiter[:, :, -1] = 1

            # Recall phase: without these blank steps the scored outputs
            # would be produced while the sequence is still being presented.
            recall_phase = torch.zeros(batch_size, seq_len, in_dim, device=self.device)

            input_seq = torch.cat([x, delimiter, recall_phase], dim=1)
            try:
                output, _ = ntm(input_seq)
                D = ntm.cfg.output_dim
                pred = output[:, -seq_len:, :D]
                d = min(in_dim, D)
                loss = F.mse_loss(pred[..., :d], x[..., :d])
            except Exception:
                # Deliberate best effort: a candidate that cannot run the
                # task is simply unfit.  KeyboardInterrupt / SystemExit are
                # not Exception subclasses and still propagate.
                return 0.0
            return self._score_from_loss(loss)

    def associative_recall(self, ntm: "NeuralTuringMachine", num_pairs: int = 4,
                           batch_size: int = 8) -> float:
        """Score key -> value recall.

        ``num_pairs`` random (key, value) vectors are presented, followed by
        the same keys with the value half zeroed.  The outputs on those query
        steps are compared with vectors whose key half is zero and whose
        value half holds the matching value.

        Args:
            ntm: model called as ``ntm(input_seq)`` returning ``(output, state)``.
            num_pairs: number of pairs stored and then queried.
            batch_size: number of independent episodes.

        Returns:
            ``1 / (1 + mse)``, or ``0.0`` on a failing forward pass or a
            non-finite loss.
        """
        with torch.no_grad():
            dim = ntm.cfg.input_dim
            # Keys take the first half; values take the remainder, so an odd
            # input_dim still yields full-width inputs.  For an even dim this
            # is the same dim // 2 split for both halves.
            key_dim = dim // 2
            value_dim = dim - key_dim
            keys = torch.randn(batch_size, num_pairs, key_dim, device=self.device)
            values = torch.randn(batch_size, num_pairs, value_dim, device=self.device)
            pairs = torch.cat([keys, values], dim=-1)

            test_keys = torch.cat([keys, torch.zeros_like(values)], dim=-1)
            expected_values = torch.cat([torch.zeros_like(keys), values], dim=-1)

            input_seq = torch.cat([pairs, test_keys], dim=1)

            try:
                output, _ = ntm(input_seq)
                d = min(dim, ntm.cfg.output_dim)
                # Only the query steps (after the num_pairs storage steps)
                # are scored.
                loss = F.mse_loss(output[:, num_pairs:, :d], expected_values[..., :d])
            except Exception:
                # Same best-effort policy as copy_task.
                return 0.0
            return self._score_from_loss(loss)

    def evaluate_fitness(self, ntm: "NeuralTuringMachine") -> Dict[str, float]:
        """Return the individual and combined fitness scores of ``ntm``.

        Returns:
            Dict with keys ``'copy'``, ``'recall'``, ``'efficiency'`` and
            ``'composite'``.  Efficiency is ``1 / (1 + params / 100000)`` and
            the composite is ``0.5*copy + 0.3*recall + 0.2*efficiency``.
        """
        copy_score = self.copy_task(ntm)
        recall_score = self.associative_recall(ntm)

        param_count = sum(p.numel() for p in ntm.parameters())
        efficiency = 1.0 / (1.0 + param_count / 100000)

        composite_score = 0.5 * copy_score + 0.3 * recall_score + 0.2 * efficiency

        return {
            'copy': copy_score,
            'recall': recall_score,
            'efficiency': efficiency,
            'composite': composite_score,
        }
|
|
| |
| |
|
|
class EvolutionaryTuringMachine:
    """Evolutionary search over ``NeuralTuringMachine`` architectures and weights.

    Each generation scores every candidate with ``FitnessEvaluator``, keeps
    the top candidates as elites and refills the population with crossover
    children, weight-mutated copies and architecture-mutated (freshly
    initialised) machines.
    """

    # (field, max step, lower bound, upper bound or None) for each
    # hyper-parameter that ``mutate_architecture`` may perturb, in draw order.
    _ARCH_MUTATIONS = (
        ('controller_dim', 32, 32, None),
        ('memory_slots', 16, 16, None),
        ('memory_dim', 8, 8, None),
        ('heads_read', 1, 1, 4),
        ('heads_write', 1, 1, 3),
    )

    def __init__(self, cfg: "EvolutionaryTuringConfig"):
        self.cfg = cfg
        self.evaluator = FitnessEvaluator(cfg.device)
        self.generation = 0
        self.best_fitness = 0.0
        self.population = []

        if cfg.seed is not None:
            torch.manual_seed(cfg.seed)

    def create_random_config(self) -> "NTMConfig":
        """Sample a random architecture (upper bounds of ``randint`` are exclusive)."""
        return NTMConfig(
            input_dim=self.cfg.input_dim,
            output_dim=self.cfg.output_dim,
            controller_dim=torch.randint(64, 256, (1,)).item(),
            controller_layers=torch.randint(1, 3, (1,)).item(),
            memory_slots=torch.randint(32, 256, (1,)).item(),
            memory_dim=torch.randint(16, 64, (1,)).item(),
            heads_read=torch.randint(1, 4, (1,)).item(),
            heads_write=torch.randint(1, 3, (1,)).item(),
            init_std=0.1,
        )

    def mutate_architecture(self, cfg: "NTMConfig") -> "NTMConfig":
        """Return a perturbed deep copy of ``cfg``; ``cfg`` itself is untouched.

        Each field listed in ``_ARCH_MUTATIONS`` is, independently and with
        probability ``architecture_mutation_rate``, shifted by a uniform
        integer in ``[-step, step]`` and clamped to its bounds.
        """
        new_cfg = deepcopy(cfg)
        for name, step, lower, upper in self._ARCH_MUTATIONS:
            if torch.rand(1) < self.cfg.architecture_mutation_rate:
                value = getattr(new_cfg, name) + torch.randint(-step, step + 1, (1,)).item()
                if upper is not None:
                    value = min(upper, value)
                setattr(new_cfg, name, max(lower, value))
        return new_cfg

    def mutate_parameters(self, ntm: "NeuralTuringMachine") -> "NeuralTuringMachine":
        """Return a copy of ``ntm`` with sparse Gaussian noise added to its weights.

        Every parameter element is perturbed with probability
        ``mutation_rate`` by noise of standard deviation 0.01.  ``ntm`` is
        not modified.
        """
        new_ntm = NeuralTuringMachine(ntm.cfg).to(self.cfg.device)
        # load_state_dict copies values into the new parameters, so the
        # source state dict does not need to be deep-copied first.
        new_ntm.load_state_dict(ntm.state_dict())
        with torch.no_grad():
            for p in new_ntm.parameters():
                mask = torch.rand_like(p) < self.cfg.mutation_rate
                p.add_(torch.randn_like(p) * 0.01 * mask)
        return new_ntm

    def crossover(self, parent1: "NeuralTuringMachine", parent2: "NeuralTuringMachine") -> "NeuralTuringMachine":
        """Create a freshly initialised child mixing both parents' hyper-parameters.

        Each architecture field is taken from either parent with equal
        probability.  No weights are inherited.  ``controller_layers``
        follows whichever parent supplied ``controller_dim`` so it is no
        longer silently reset to the dataclass default.
        """
        cfg1, cfg2 = parent1.cfg, parent2.cfg

        def pick(field: str):
            return getattr(cfg1 if torch.rand(1) < 0.5 else cfg2, field)

        # One draw decides both controller fields, keeping the number and
        # order of random draws the same as before.
        controller_src = cfg1 if torch.rand(1) < 0.5 else cfg2
        new_cfg = NTMConfig(
            input_dim=self.cfg.input_dim,
            output_dim=self.cfg.output_dim,
            controller_dim=controller_src.controller_dim,
            controller_layers=controller_src.controller_layers,
            memory_slots=pick('memory_slots'),
            memory_dim=pick('memory_dim'),
            heads_read=pick('heads_read'),
            heads_write=pick('heads_write'),
            init_std=0.1,
        )

        return NeuralTuringMachine(new_cfg).to(self.cfg.device)

    def initialize_population(self):
        """Replace the population with ``population_size`` random machines."""
        self.population = [
            NeuralTuringMachine(self.create_random_config()).to(self.cfg.device)
            for _ in range(self.cfg.population_size)
        ]

    def _elite_count(self, available: int) -> int:
        """Number of elites to keep: at least 1, at most ``available``."""
        wanted = int(self.cfg.elite_ratio * self.cfg.population_size)
        return max(1, min(wanted, available))

    def _composite_scores(self) -> List[float]:
        """Evaluate every member of the population and return composite scores."""
        return [self.evaluator.evaluate_fitness(ntm)['composite'] for ntm in self.population]

    def evolve_generation(self) -> Dict[str, float]:
        """Evaluate the population, select elites and breed the next generation.

        Returns:
            Dict with ``'generation'`` (counter after this step),
            ``'best_fitness'`` and ``'avg_fitness'`` of the evaluated
            population, and ``'best_ever'`` across all generations.

        Raises:
            ValueError: if the population is empty.
        """
        if not self.population:
            raise ValueError(
                "population is empty; call initialize_population() with population_size >= 1 first"
            )

        fitness_scores = self._composite_scores()
        ranked = sorted(range(len(fitness_scores)), key=fitness_scores.__getitem__, reverse=True)

        # Always keep at least one elite: with a small population
        # int(elite_ratio * population_size) can be 0, which left no parent
        # to breed from.
        elite_count = self._elite_count(len(self.population))
        elites = [self.population[i] for i in ranked[:elite_count]]

        new_population = list(elites)

        while len(new_population) < self.cfg.population_size:
            if torch.rand(1) < 0.3 and len(elites) >= 2:
                # Two distinct elites as parents.
                first, second = torch.randperm(len(elites))[:2].tolist()
                child = self.crossover(elites[first], elites[second])
            else:
                parent = elites[torch.randint(0, len(elites), (1,)).item()]
                if torch.rand(1) < 0.5:
                    child = self.mutate_parameters(parent)
                else:
                    # A changed architecture gets fresh weights: shapes may
                    # no longer match the parent's.
                    new_cfg = self.mutate_architecture(parent.cfg)
                    child = NeuralTuringMachine(new_cfg).to(self.cfg.device)

            new_population.append(child)

        self.population = new_population[:self.cfg.population_size]
        self.generation += 1

        best_fitness = max(fitness_scores)
        avg_fitness = sum(fitness_scores) / len(fitness_scores)
        self.best_fitness = max(self.best_fitness, best_fitness)

        return {
            'generation': self.generation,
            'best_fitness': best_fitness,
            'avg_fitness': avg_fitness,
            'best_ever': self.best_fitness,
        }

    def run_evolution(self) -> List[Dict[str, float]]:
        """Initialise a population and evolve it for ``max_generations``.

        Prints a progress line every 10th generation and returns the list of
        per-generation statistics from ``evolve_generation``.
        """
        self.initialize_population()

        history = []
        for gen in range(self.cfg.max_generations):
            stats = self.evolve_generation()
            history.append(stats)

            if gen % 10 == 0:
                print(f"Gen {gen}: Best={stats['best_fitness']:.4f}, Avg={stats['avg_fitness']:.4f}")

        return history

    def get_best_model(self) -> "NeuralTuringMachine":
        """Re-evaluate the current population and return its best member.

        Scores are recomputed on freshly sampled task data, so the result
        can differ from the ranking seen during the last generation.

        Raises:
            ValueError: if the population is empty.
        """
        if not self.population:
            raise ValueError("population is empty; run the evolution or initialize_population() first")

        fitness_scores = self._composite_scores()
        best_idx = max(range(len(fitness_scores)), key=fitness_scores.__getitem__)
        return self.population[best_idx]
|
|