Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| import json | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
# ---- configuration -------------------------------------------------------
TRAIN_PATH = Path("data/train.jsonl")   # input JSONL: one triplet record per line
MODEL_OUT = Path("vil-encoder-v2.pt")   # checkpoint destination

SEQ_LEN = 64         # fixed encoded-sequence length (pad/truncate target)
EMBED_DIM = 32       # output embedding dimensionality
BATCH_SIZE = 128
EPOCHS = 12
LR = 1e-3
WEIGHT_DECAY = 1e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Deterministic runs: seed both torch and numpy up front.
SEED = 918
torch.manual_seed(SEED)
np.random.seed(SEED)
def encode_triplet(visible: str, braille: str, hanzi: str) -> np.ndarray:
    """Encode a (visible, braille, hanzi) triplet as a fixed-length float vector.

    The three strings are joined with '|' separators, each character is
    mapped to its code point modulo 256, the sequence is truncated or
    zero-padded to SEQ_LEN, and values are scaled into [0, 1].

    NOTE(review): ``ord(c) % 256`` folds distinct code points together
    (e.g. braille U+28xx collides with ASCII) — confirm this lossy
    byte-folding is intentional for this encoder.
    """
    joined = "|".join((visible, braille, hanzi))
    codes = [ord(ch) % 256 for ch in joined[:SEQ_LEN]]
    codes.extend([0] * (SEQ_LEN - len(codes)))  # right-pad with zeros
    return np.asarray(codes, dtype=np.float32) / np.float32(255.0)
def load_rows(path: Path) -> List[dict]:
    """Read a JSONL file, skipping blank lines; raise if nothing is loaded.

    Raises:
        RuntimeError: when the file contains no non-blank lines.
    """
    with path.open("r", encoding="utf-8") as handle:
        parsed = [json.loads(raw) for raw in map(str.strip, handle) if raw]
    if not parsed:
        raise RuntimeError(f"No rows loaded from {path}")
    return parsed
class PairDataset(Dataset):
    """Dataset yielding (anchor, positive) pairs of encoded triplets.

    All rows are encoded once up front; the "positive" for index ``i`` is
    simply the encoding at index ``(i + 1) % N``.

    NOTE(review): adjacent-row pairing assumes neighbouring rows in the
    JSONL are semantically related — confirm the file ordering provides
    this before relying on the learned embeddings.
    """

    def __init__(self, rows: List[dict]) -> None:
        self.rows = rows
        encoded = [
            encode_triplet(r["visible"], r["braille"], r["hanzi"]) for r in rows
        ]
        self.inputs = np.stack(encoded).astype(np.float32)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Neighbour index, wrapping the final row back to the first.
        neighbour = (idx + 1) % len(self.inputs)
        anchor = torch.from_numpy(self.inputs[idx])
        positive = torch.from_numpy(self.inputs[neighbour])
        return anchor, positive
class Encoder(nn.Module):
    """Small MLP mapping a SEQ_LEN input vector to a unit-norm embedding."""

    def __init__(self, input_dim: int = SEQ_LEN, embed_dim: int = EMBED_DIM) -> None:
        super().__init__()
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, embed_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2-normalise so embeddings lie on the unit hypersphere
        # (pairs well with the cosine-based loss used in training).
        return nn.functional.normalize(self.net(x), dim=-1)
def cosine_pull_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Mean (1 - cosine similarity): 0 when aligned, up to 2 when opposed."""
    similarity = nn.functional.cosine_similarity(a, b)
    return 1.0 - similarity.mean()
def main() -> None:
    """Train the encoder to pull adjacent-row embeddings together.

    Loads the training JSONL, runs EPOCHS passes of the cosine pull loss
    over (anchor, positive) pairs, and saves a checkpoint to MODEL_OUT
    whenever an epoch improves on the best mean loss seen so far.
    """
    rows = load_rows(TRAIN_PATH)
    loader = DataLoader(
        PairDataset(rows),
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=False,
    )
    model = Encoder().to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_loss = float("inf")
    history = []
    for epoch in range(EPOCHS):
        model.train()
        total = 0.0
        n_batches = 0
        for anchors, positives in loader:
            anchors = anchors.to(DEVICE)
            positives = positives.to(DEVICE)
            loss = cosine_pull_loss(model(anchors), model(positives))
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            total += float(loss.item())
            n_batches += 1
        epoch_loss = total / max(1, n_batches)
        history.append(epoch_loss)
        print(f"epoch={epoch:02d} loss={epoch_loss:.6f}")
        # Checkpoint only on improvement; history up to this epoch rides along.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "config": {
                        "input_dim": SEQ_LEN,
                        "embed_dim": EMBED_DIM,
                    },
                    "history": history,
                },
                MODEL_OUT,
            )
    print(f"saved={MODEL_OUT} best_loss={best_loss:.6f}")


if __name__ == "__main__":
    main()