"""
SinCode v3: ByT5 Seq2Seq + XLM-RoBERTa MLM Reranker.
Pipeline (per word):
Sinhala script  → kept as-is (single candidate; MLM scoring skipped)
English vocab   → raw English kept alongside ByT5 Sinhala alternatives; MLM picks
Everything else → ByT5 generates top-5 candidates; MLM picks the best
"""
import math
import re
import torch
import logging
from typing import List, Tuple, Optional
from transformers import AutoTokenizer, AutoModelForMaskedLM
from core.constants import (
DEFAULT_MLM_MODEL, DEFAULT_BYT5_MODEL,
MAX_CANDIDATES, MIN_ENGLISH_LEN, MIN_ENGLISH_PASSTHROUGH_LEN,
PUNCT_PATTERN,
)
from core.english import ENGLISH_VOCAB
from seq2seq.infer import Transliterator
logger = logging.getLogger(__name__)
_SINHALA_RE = re.compile(r"[\u0D80-\u0DFF]")
# ── Numeric / special-token passthrough ──────────────────────────────────────
# These patterns detect tokens that must not go through ByT5 transliteration.
_RE_ORDINAL = re.compile(r"^\d+(st|nd|rd|th)$", re.IGNORECASE) # 1st, 2nd, 3rd
_RE_PURE_NUM = re.compile(r"^\d+(?:[.,]\d+)*$") # 5, 10.30, 9.00
_RE_NUM_RANGE = re.compile(r"^\d+-\d+$") # 2-3, 10-20
_RE_CURRENCY = re.compile(r"^\d+[/]-?$") # 500/-
_RE_AM_PM = re.compile(r"^[ap]\.?m\.?$", re.IGNORECASE) # a.m. p.m. am pm
_RE_NUM_PCT_K = re.compile(r"^(\d[\d,.]*%+)k$", re.IGNORECASE)  # 100%k → 100%ක්
_RE_NUM_K = re.compile(r"^(\d[\d,.]*)k$", re.IGNORECASE)        # 5000k → 5000ක්
_RE_NUM_TA = re.compile(r"^(\d[\d,.]*)ta$", re.IGNORECASE)      # 10.30ta → 10.30ට
def _numeric_passthrough(core: str) -> Optional[str]:
"""
Return a (possibly lightly-transformed) value for numeric/special tokens.
Returns None if the token should go through normal ByT5 + MLM processing.
"""
    # 100%k → 100%ක් (check before plain Nk to avoid consuming the %)
    m = _RE_NUM_PCT_K.match(core)
    if m:
        return m.group(1) + "ක්"
    # 5000k / 10k → 5000ක් / 10ක්
    m = _RE_NUM_K.match(core)
    if m:
        return m.group(1) + "ක්"
    # 10.30ta → 10.30ට
    m = _RE_NUM_TA.match(core)
    if m:
        return m.group(1) + "ට"
    # Ordinals: 1st, 2nd, 3rd … → keep as-is
    if _RE_ORDINAL.match(core):
        return core
    # Pure numbers and decimals: 5, 100, 10.30, 9.00 → keep
    if _RE_PURE_NUM.match(core):
        return core
    # Number ranges: 2-3 → keep
    if _RE_NUM_RANGE.match(core):
        return core
    # Currency notation: 500/- → keep
    if _RE_CURRENCY.match(core):
        return core
    # AM/PM markers: a.m. p.m. am pm → keep
    if _RE_AM_PM.match(core):
        return core
return None
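
# Illustrative results of the rules above (hypothetical calls):
#   _numeric_passthrough("10.30ta") → "10.30ට"
#   _numeric_passthrough("500/-")   → "500/-"
#   _numeric_passthrough("hello")   → None  (falls through to ByT5 + MLM)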
class ScoredCandidate:
__slots__ = ("text", "mlm_score")
def __init__(self, text: str, mlm_score: float):
self.text = text
self.mlm_score = mlm_score
def _is_sinhala(text: str) -> bool:
return bool(_SINHALA_RE.search(text))
class BeamSearchDecoder:
"""
SinCode v3 contextual decoder.
Replaces the rule engine + dictionary + hardcoded maps with a single
    ByT5-small seq2seq model fine-tuned on 1,000,000 Singlish→Sinhala pairs.
XLM-RoBERTa reranks the top-5 beam candidates by masked-LM probability.
"""
def __init__(
self,
mlm_model_name: str = DEFAULT_MLM_MODEL,
byt5_model_path: str = DEFAULT_BYT5_MODEL,
device: Optional[str] = None,
):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
logger.info("Loading MLM reranker: %s", mlm_model_name)
self.tokenizer = AutoTokenizer.from_pretrained(mlm_model_name)
self.model = AutoModelForMaskedLM.from_pretrained(mlm_model_name)
self.model.to(self.device)
self.model.eval()
logger.info("Loading ByT5 transliterator: %s", byt5_model_path)
self.transliterator = Transliterator(model_path=byt5_model_path, device=self.device)
    # ── Normalization ──────────────────────────────────────────────────────────
@staticmethod
def _softmax_normalize(raw_scores: List[float]) -> List[float]:
if not raw_scores:
return []
if len(raw_scores) == 1:
return [1.0]
max_s = max(raw_scores)
exps = [math.exp(s - max_s) for s in raw_scores]
total = sum(exps)
return [e / total for e in exps]
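
    # Worked example: _softmax_normalize([-1.0, -2.0]) ≈ [0.731, 0.269];
    # after max-shifting, exp(0) = 1.0 and exp(-1) ≈ 0.368, and 1.0 / 1.368 ≈ 0.731.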
    # ── MLM batch scoring ──────────────────────────────────────────────────────
def _batch_mlm_score(
self,
left_contexts: List[str],
right_contexts: List[str],
candidates: List[str],
) -> List[float]:
"""Score each candidate with XLM-RoBERTa multi-mask log-probability."""
if not candidates:
return []
mask = self.tokenizer.mask_token
mask_token_id = self.tokenizer.mask_token_id
cand_token_ids: List[List[int]] = []
for c in candidates:
ids = self.tokenizer.encode(c, add_special_tokens=False)
cand_token_ids.append(ids if ids else [self.tokenizer.unk_token_id])
batch_texts: List[str] = []
for i in range(len(candidates)):
n_masks = len(cand_token_ids[i])
mask_str = " ".join([mask] * n_masks)
parts = [p for p in [left_contexts[i], mask_str, right_contexts[i]] if p]
batch_texts.append(" ".join(parts))
inputs = self.tokenizer(
batch_texts,
return_tensors="pt",
padding=True,
truncation=True,
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits
scores: List[float] = []
for i, target_ids in enumerate(cand_token_ids):
token_ids = inputs.input_ids[i]
mask_positions = (token_ids == mask_token_id).nonzero(as_tuple=True)[0]
if mask_positions.numel() == 0 or not target_ids:
scores.append(-100.0)
continue
n = min(len(target_ids), mask_positions.numel())
total = 0.0
for j in range(n):
pos = mask_positions[j].item()
log_probs = torch.log_softmax(logits[i, pos, :], dim=0)
total += log_probs[target_ids[j]].item()
scores.append(total / n)
return scores
    # ── Public decode ──────────────────────────────────────────────────────────
def decode(self, sentence: str) -> Tuple[str, List[str], List[Tuple[str, List[str]]]]:
"""
Decode a Singlish sentence word-by-word using ByT5 + XLM-RoBERTa MLM.
Returns (transliterated_sentence, trace_logs, word_candidates).
word_candidates is a list of (selected_word, [all_candidates]) per word,
in input word order. Single-candidate words (passthrough/sinhala) have an
empty alternatives list.
"""
words = sentence.split()
if not words:
return "", [], []
        # ── Phase 1: batch ByT5 candidate generation ──────────────────────────
# Collect only the words that need ByT5 (non-Sinhala), run in one pass
cores: List[str] = []
core_meta: List[tuple] = [] # (index_into_words, prefix, core, suffix, core_lower)
for i, raw in enumerate(words):
match = PUNCT_PATTERN.match(raw)
prefix, core, suffix = match.groups() if match else ("", raw, "")
if not _is_sinhala(core):
                # Skip numeric/special tokens: they don't need ByT5
if _numeric_passthrough(core) is not None:
continue
cores.append(core)
core_meta.append((i, prefix, core, suffix, core.lower()))
# Single ByT5 forward pass for all non-Sinhala words
byt5_results: List[List[str]] = (
self.transliterator.batch_candidates(cores, k=MAX_CANDIDATES)
if cores else []
)
        byt5_map: dict = {}  # word index → (prefix, suffix, core_lower, candidates)
for (i, prefix, core, suffix, core_lower), cands in zip(core_meta, byt5_results):
byt5_map[i] = (prefix, suffix, core_lower, cands or [core])
word_infos: List[dict] = []
for i, raw in enumerate(words):
match = PUNCT_PATTERN.match(raw)
raw_prefix, core, raw_suffix = match.groups() if match else ("", raw, "")
if _is_sinhala(core):
word_infos.append({"kind": "sinhala", "candidates": [raw]})
continue
            # Numeric/special passthrough: keep or lightly transform, skip MLM
passthrough_val = _numeric_passthrough(core)
if passthrough_val is not None:
word_infos.append({"kind": "passthrough", "candidates": [raw_prefix + passthrough_val + raw_suffix]})
continue
prefix, suffix, core_lower, byt5_cands = byt5_map[i]
sinhala_cands = [prefix + c + suffix for c in byt5_cands]
if core_lower in ENGLISH_VOCAB and len(core_lower) >= MIN_ENGLISH_LEN:
if len(core_lower) >= MIN_ENGLISH_PASSTHROUGH_LEN:
                    # Long English loanwords (≥6 chars): unambiguously English in
                    # code-mixed Singlish, so bypass MLM to avoid Sinhala-bias override.
word_infos.append({"kind": "english", "candidates": [raw]})
else:
                    # Short English vocab words (3–5 chars) may also be Singlish
                    # homophones (mage/mama/game/call), so let MLM disambiguate.
candidates = [raw] + [c for c in sinhala_cands if c != raw]
word_infos.append({"kind": "singlish", "candidates": candidates[:MAX_CANDIDATES + 1]})
else:
word_infos.append({"kind": "singlish", "candidates": sinhala_cands})
        # ── Phase 2: greedy left-to-right pass (builds dynamic left context) ──
# Right context is seeded from first ByT5 candidate (pre-decode estimate)
stable_right = [info["candidates"][0] for info in word_infos]
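        # e.g. before anything is decoded, word 0's right context is built from
        # each later word's top candidate (the raw word for passthroughs).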
selected_words: List[str] = []
for t, info in enumerate(word_infos):
            # English/passthrough/Sinhala words are single-candidate: keep them and
            # skip MLM (its Sinhala bias would otherwise score a Sinhala
            # transliteration higher than the English token)
if info["kind"] in ("english", "passthrough", "sinhala"):
selected_words.append(info["candidates"][0])
continue
candidates = info["candidates"]
left_ctx = " ".join(selected_words)
right_ctx = " ".join(stable_right[t + 1:])
raw_mlm = self._batch_mlm_score(
[left_ctx] * len(candidates),
[right_ctx] * len(candidates),
candidates,
)
norm_mlm = self._softmax_normalize(raw_mlm)
best = max(zip(candidates, norm_mlm), key=lambda x: x[1])
selected_words.append(best[0])
        # ── Phase 3: re-score with full decoded sentence as context ───────────
# Right context is now the actual decoded output, not the pre-decode estimate
trace_logs: List[str] = []
final_words: List[str] = []
word_candidates: List[Tuple[str, List[str]]] = [] # (selected, [all_cands])
for t, info in enumerate(word_infos):
raw_word = words[t]
kind = info["kind"]
candidates = info["candidates"]
# English-detected, Sinhala, and passthrough words bypass MLM scoring.
# For English: the MLM model is Sinhala-biased and would otherwise prefer
# a Sinhala transliteration over the correct English token.
if kind == "sinhala":
final_words.append(candidates[0])
word_candidates.append((candidates[0], []))
trace_logs.append(
f"**Step {t+1}: `{raw_word}`** β `{candidates[0]}` (Sinhala passthrough)\n"
)
continue
if kind == "passthrough":
final_words.append(candidates[0])
word_candidates.append((candidates[0], []))
trace_logs.append(
f"**Step {t+1}: `{raw_word}`** β `{candidates[0]}` (numeric/passthrough)\n"
)
continue
if kind == "english":
final_words.append(candidates[0])
word_candidates.append((candidates[0], []))
trace_logs.append(
f"**Step {t+1}: `{raw_word}`** β `{candidates[0]}` (English vocab passthrough)\n"
)
continue
left_ctx = " ".join(final_words)
right_ctx = " ".join(selected_words[t + 1:])
raw_mlm = self._batch_mlm_score(
[left_ctx] * len(candidates),
[right_ctx] * len(candidates),
candidates,
)
norm_mlm = self._softmax_normalize(raw_mlm)
scored = sorted(
[ScoredCandidate(text=c, mlm_score=norm_mlm[i]) for i, c in enumerate(candidates)],
key=lambda x: x.mlm_score,
reverse=True,
)
best = scored[0]
final_words.append(best.text)
word_candidates.append((best.text, [s.text for s in scored]))
trace_logs.append(
f"**Step {t+1}: `{raw_word}`** β `{best.text}` "
f"(MLM={best.mlm_score:.3f})\n"
+ "\n".join(f" - `{s.text}` {s.mlm_score:.3f}" for s in scored)
)
return " ".join(final_words), trace_logs, word_candidates