| from __future__ import annotations
|
|
|
| from dataclasses import dataclass
|
| from math import exp
|
| from time import perf_counter
|
|
|
|
|
| @dataclass(slots=True)
|
| class StepTelemetry:
|
| epoch: int
|
| steps: int
|
| wall_time_sec: float
|
| memory_rss_mb: float
|
| child_processes: int
|
| thread_count: int
|
| predictability_score: float
|
| final_accuracy: float
|
| final_loss: float
|
| learned_gate_mean: float
|
| learned_gate_std: float
|
|
|
|
|
| @dataclass(slots=True)
|
| class GateDemoResult:
|
| initial_accuracy: float
|
| final_accuracy: float
|
| final_loss: float
|
| reached_target: bool
|
| trained_steps: int
|
| target_accuracy: float
|
| learned_gates: list[float]
|
| learned_gate_sample: list[float]
|
| telemetry: list[StepTelemetry]
|
|
|
|
|
def _process_snapshot() -> tuple[float, int, int]:
    """Return ``(rss_mb, child_count, thread_count)`` for this process.

    The figures come from ``psutil``, imported lazily. This is a
    best-effort probe: any exception, including ``psutil`` being
    unavailable, produces the neutral snapshot ``(0.0, 0, 0)``.
    """
    try:
        import psutil

        proc = psutil.Process()
        snapshot = (
            proc.memory_info().rss / (1024 * 1024),  # bytes -> MiB
            len(proc.children(recursive=True)),
            proc.num_threads(),
        )
    except Exception:
        snapshot = (0.0, 0, 0)
    return snapshot
|
|
|
|
|
def _gate_mean_std(gates: list[float]) -> tuple[float, float]:
    """Return the mean and population standard deviation of *gates*.

    The divisor is clamped to 1, so an empty list yields ``(0.0, 0.0)``.
    """
    count = max(len(gates), 1)
    mean = sum(gates) / count
    variance = sum((gate - mean) ** 2 for gate in gates) / count
    return mean, variance ** 0.5


def run_tinygrad_gate_demo(
    steps: int = 80,
    batch_size: int = 64,
    seed: int = 0,
    target_accuracy: float = 0.99,
) -> GateDemoResult:
    """Train a gated linear probe with tinygrad to imitate a fixed teacher.

    A teacher probe with fixed log-gates labels 128 random 12-dimensional
    samples. A student probe starting from zero log-gates is then trained
    with SGD on the full sample set to reproduce those labels.

    Args:
        steps: Maximum number of SGD updates. Telemetry is recorded roughly
            eight times per run (every ``max(1, steps // 8)`` updates) and
            always on the last update.
        batch_size: Accepted for interface compatibility but currently
            unused; every update trains on all samples.
        seed: Passed to ``Tensor.manual_seed``.
        target_accuracy: Training stops early at the first telemetry
            checkpoint whose accuracy is at least this value.

    Returns:
        A ``GateDemoResult`` with the initial and final metrics, the learned
        gate values (``exp`` of the student's log-gates), the first eight of
        those gates as a sample, and the collected telemetry records.

    Raises:
        RuntimeError: If tinygrad cannot be imported.
    """
    try:
        from tinygrad import Tensor, nn
        from tinygrad.nn.state import get_parameters
    except ImportError as exc:
        raise RuntimeError("tinygrad demo requires tinygrad to be installed") from exc

    Tensor.manual_seed(seed)

    input_dim = 12
    samples = 128

    features = Tensor.randn(samples, input_dim)

    class GatedProbe:
        """Two-class scorer: per-feature base weights scaled by exp(log-gates)."""

        def __init__(self) -> None:
            # NOTE(review): is_param_(False) presumably excludes this tensor
            # from optimisation - confirm against the tinygrad version in use.
            self.base_weights = Tensor.linspace(0.5, 1.5, input_dim).is_param_(False)
            self.log_gates = Tensor.zeros(input_dim)

        def __call__(self, x: Tensor) -> Tensor:
            score = (x * self.base_weights * self.log_gates.exp()).sum(axis=1)
            # Logits for class 0 and class 1 are -score and +score.
            return Tensor.stack(-score, score, dim=1)

    # The teacher supplies the ground-truth labels from a fixed gate profile.
    teacher = GatedProbe()
    teacher.log_gates = Tensor.linspace(-0.25, 0.75, input_dim).is_param_(False)
    labels = teacher(features).argmax(-1)

    student = GatedProbe()
    optimizer = nn.optim.SGD(get_parameters(student), lr=0.8)

    def accuracy(model: GatedProbe) -> float:
        """Return the fraction of samples whose predicted class matches the label."""
        pred = model(features).argmax(-1)
        return float((pred == labels).sum().item()) / samples

    initial_accuracy = accuracy(student)

    telemetry: list[StepTelemetry] = []
    checkpoint_every = max(1, steps // 8)
    reached_target = False
    trained_steps = 0
    start_time = perf_counter()

    Tensor.training = True
    try:
        for epoch in range(1, steps + 1):
            optimizer.zero_grad()
            student(features).sparse_categorical_crossentropy(labels).backward()
            optimizer.step()
            trained_steps = epoch

            if epoch != steps and epoch % checkpoint_every != 0:
                continue

            current_logits = student(features)
            current_loss = float(current_logits.sparse_categorical_crossentropy(labels).item())
            current_accuracy = accuracy(student)
            memory_rss_mb, child_processes, thread_count = _process_snapshot()
            checkpoint_gates = [float(x) for x in student.log_gates.exp().tolist()]
            gate_mean, gate_std = _gate_mean_std(checkpoint_gates)
            telemetry.append(
                StepTelemetry(
                    epoch=epoch,
                    steps=epoch,
                    wall_time_sec=perf_counter() - start_time,
                    memory_rss_mb=memory_rss_mb,
                    child_processes=child_processes,
                    thread_count=thread_count,
                    predictability_score=float(exp(-current_loss) * 100.0),
                    final_accuracy=current_accuracy,
                    final_loss=current_loss,
                    learned_gate_mean=gate_mean,
                    learned_gate_std=gate_std,
                )
            )
            # Early stopping is only evaluated at telemetry checkpoints.
            if current_accuracy >= target_accuracy:
                reached_target = True
                break
    finally:
        # Reset the global training flag even when a step raises, so a
        # failed run does not leave tinygrad in training mode.
        Tensor.training = False

    final_logits = student(features)
    final_loss = float(final_logits.sparse_categorical_crossentropy(labels).item())
    final_accuracy = accuracy(student)
    learned_gates = [float(x) for x in student.log_gates.exp().tolist()]
    gate_sample = learned_gates[:8]

    return GateDemoResult(
        initial_accuracy=initial_accuracy,
        final_accuracy=final_accuracy,
        final_loss=final_loss,
        reached_target=reached_target,
        trained_steps=trained_steps,
        target_accuracy=target_accuracy,
        learned_gates=learned_gates,
        learned_gate_sample=gate_sample,
        telemetry=telemetry,
    )
|
|
|