#!/usr/bin/env python3
"""Train a small MLP encoder on character-code encodings of JSONL rows.

Each non-blank line of ``TRAIN_PATH`` is a JSON object with ``visible``,
``braille`` and ``hanzi`` string fields. Rows are encoded into fixed-length
float vectors, paired with a "positive" partner row, and the encoder is
trained with an in-batch contrastive objective. The checkpoint from the epoch
with the lowest loss is written to ``MODEL_OUT``.
"""
import json
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

TRAIN_PATH = Path("data/train.jsonl")
MODEL_OUT = Path("vil-encoder-v2.pt")

SEQ_LEN = 64
EMBED_DIM = 32
BATCH_SIZE = 128
EPOCHS = 12
LR = 1e-3
WEIGHT_DECAY = 1e-5
# Divisor applied to the similarity logits in info_nce_loss; smaller values
# sharpen the softmax over positive vs. in-batch negatives.
TEMPERATURE = 0.1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 918

torch.manual_seed(SEED)
np.random.seed(SEED)


def encode_triplet(visible: str, braille: str, hanzi: str) -> np.ndarray:
    """Encode three text fields as a float32 vector of length ``SEQ_LEN``.

    The fields are joined with ``"|"`` and each character becomes
    ``(ord(c) % 256) / 255.0``. Shorter texts are right-padded with 0.0.
    Longer texts are truncated to the first ``SEQ_LEN`` characters, so the
    trailing field(s) may be cut off.

    NOTE(review): ``% 256`` aliases every code point above U+00FF. Braille
    (U+2800..U+28FF) and CJK characters collide with ASCII/control codes,
    and U+2800 encodes to 0.0, the same value as padding. The scheme is kept
    unchanged because anything consuming a trained checkpoint must reproduce
    it exactly.
    """
    text = f"{visible}|{braille}|{hanzi}"
    arr = np.array([ord(c) % 256 for c in text], dtype=np.float32)
    if arr.shape[0] < SEQ_LEN:
        arr = np.pad(arr, (0, SEQ_LEN - arr.shape[0]))
    else:
        arr = arr[:SEQ_LEN]
    arr /= 255.0
    return arr


def load_rows(path: Path) -> List[dict]:
    """Parse every non-blank line of a JSONL file.

    Raises:
        json.JSONDecodeError: a line is not valid JSON. The message is
            prefixed with ``path:lineno`` to locate the bad record.
        RuntimeError: the file contains no non-blank lines.
    """
    rows: List[dict] = []
    with path.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type callers may catch, but name the record.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            rows.append(record)
    if not rows:
        raise RuntimeError(f"No rows loaded from {path}")
    return rows


class PairDataset(Dataset):
    """Dataset yielding ``(anchor, positive)`` pairs of encoded rows.

    The positive for row ``idx`` is row ``(idx + 1) % len(rows)``: the next
    row in file order, wrapping to the first row for the last one.

    NOTE(review): this yields meaningful positives only if consecutive rows
    are meant to match. Confirm against how the JSONL data is produced.
    """

    def __init__(self, rows: List[dict]) -> None:
        self.rows = rows
        # Encode every row once; __getitem__ only indexes this array.
        self.inputs = np.stack([
            encode_triplet(r["visible"], r["braille"], r["hanzi"]) for r in rows
        ]).astype(np.float32)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        anchor = self.inputs[idx]
        pos_idx = (idx + 1) % len(self.inputs)
        positive = self.inputs[pos_idx]
        # torch.from_numpy shares memory with self.inputs (no copy).
        return torch.from_numpy(anchor), torch.from_numpy(positive)


class Encoder(nn.Module):
    """Three-layer MLP mapping ``input_dim`` features to unit-norm embeddings."""

    def __init__(self, input_dim: int = SEQ_LEN, embed_dim: int = EMBED_DIM) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised (along the last dim) embedding of ``x``."""
        z = self.net(x)
        return nn.functional.normalize(z, dim=-1)


def cosine_pull_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``1 - mean(cosine_similarity(a, b))``.

    This is a pull-only alignment term: it is 0 whenever every pair points
    the same way, including when the encoder emits one constant vector. For
    that reason it is no longer used as the training objective; it is kept
    for callers and diagnostics.
    """
    return 1.0 - nn.functional.cosine_similarity(a, b).mean()


def info_nce_loss(
    a: torch.Tensor, b: torch.Tensor, temperature: float = TEMPERATURE
) -> torch.Tensor:
    """Symmetric InfoNCE loss with in-batch negatives.

    Row ``i`` of ``a`` is scored against every row of ``b``. ``b[i]`` is the
    positive and all other rows are negatives, and likewise for ``b`` against
    ``a``. A collapsed encoder that emits one constant vector scores
    ``log(batch_size)`` here instead of 0, so training is pushed away from
    that solution.

    Args:
        a: ``(batch, dim)`` embeddings, expected L2-normalised so the dot
            products are cosine similarities (``Encoder`` output is).
        b: ``(batch, dim)`` embeddings aligned row-wise with ``a``.
        temperature: divisor applied to the similarity logits.
    """
    logits = a @ b.t() / temperature
    targets = torch.arange(a.size(0), device=a.device)
    loss_ab = nn.functional.cross_entropy(logits, targets)
    loss_ba = nn.functional.cross_entropy(logits.t(), targets)
    return 0.5 * (loss_ab + loss_ba)


def main() -> None:
    """Train the encoder and save the lowest-loss checkpoint to ``MODEL_OUT``."""
    rows = load_rows(TRAIN_PATH)
    dataset = PairDataset(rows)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

    model = Encoder().to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_loss = float("inf")
    history = []
    MODEL_OUT.parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(EPOCHS):
        model.train()
        running = 0.0
        seen = 0
        for x1, x2 in loader:
            x1 = x1.to(DEVICE)
            x2 = x2.to(DEVICE)
            z1 = model(x1)
            z2 = model(x2)
            # A pull-only cosine loss is minimised by a constant embedding,
            # so train contrastively against in-batch negatives.
            loss = info_nce_loss(z1, z2)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch does not count as
            # much as a full one in the epoch average.
            batch_n = x1.size(0)
            running += loss.item() * batch_n
            seen += batch_n

        epoch_loss = running / max(1, seen)
        history.append(epoch_loss)
        print(f"epoch={epoch:02d} loss={epoch_loss:.6f}")

        if epoch_loss < best_loss:
            best_loss = epoch_loss
            checkpoint = {
                "model_state_dict": model.state_dict(),
                # Keys match Encoder's constructor arguments.
                "config": {
                    "input_dim": SEQ_LEN,
                    "embed_dim": EMBED_DIM,
                },
                # Loss history up to and including the saved epoch.
                "history": history,
                "objective": {"name": "info_nce", "temperature": TEMPERATURE},
            }
            torch.save(checkpoint, MODEL_OUT)

    print(f"saved={MODEL_OUT} best_loss={best_loss:.6f}")


if __name__ == "__main__":
    main()