| | |
| | import json |
| | from pathlib import Path |
| | import re |
| | import torch |
| | from transformers import AutoTokenizer, Gemma3ForCausalLM |
| | from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction |
| | from tqdm import tqdm |
| | import os |
| | import torch._dynamo |
| |
|
| |
|
# Runtime knobs: let dynamo fall back to eager execution instead of raising
# when compilation fails, and allow lower-precision internal computation
# (e.g. TF32) for float32 matmuls.
torch._dynamo.config.suppress_errors = True
torch.set_float32_matmul_precision("high")

# Translation direction.
SRC_LANG = "en"
TGT_LANG = "kk"

# Input/output locations and generation budget.
MODEL_PATH = "/raid/srp_base_model_training/abai_workspace/models/sync_kk_en/checkpoint-final"
TEST_FILE = "/raid/srp_base_model_training/abai_workspace/data/flores/en_to_kk_formatted.jsonl"
OUTPUT_JSON = f"eval_sync_KKEN_data_{SRC_LANG}_to_{TGT_LANG}.json"
MAX_NEW_TOKS = 64

# NOTE: GPU selection only takes effect if CUDA has not been initialised yet in
# this process; it is set after `import torch`, relying on torch's lazy CUDA init.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2,3,4,5"})
# Not referenced by the model-loading code, which places weights with device_map="auto".
DEVICE = "cuda"
| | |
| | |
def clean_user_field(user_str: str) -> str:
    """Drop a leading ``<src=..><tgt=..>`` tag pair and the whitespace after it.

    Only a tag pair at the very start of the string is removed. Strings that
    do not begin with both tags are returned unchanged.
    """
    match = re.match(r'^<src=[^>]+><tgt=[^>]+>\s*', user_str)
    if match is None:
        return user_str
    return user_str[match.end():]
| |
|
def load_model_and_tokenizer():
    """Load the tokenizer and Gemma3 causal-LM checkpoint from MODEL_PATH.

    The model is loaded in bfloat16, with device placement delegated to
    ``device_map="auto"``, and switched to eval mode.

    Returns:
        (tokenizer, model) tuple.
    """
    # The original message contained a mis-decoded ellipsis ("β¦").
    print(f"Loading model/tokenizer from {MODEL_PATH} ...")
    tok = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = Gemma3ForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model.eval()
    return tok, model
| |
|
def build_prompt(system: str, user: str) -> str:
    """Render a system turn and a user turn in the turn-marker chat format.

    The result ends with an open ``assistant`` turn header, with no trailing
    newline, so generation continues from there.
    """
    turns = (("system", system), ("user", user))
    rendered = "".join(
        f"<start_of_turn>{role}\n{content}<end_of_turn>\n" for role, content in turns
    )
    return rendered + "<start_of_turn>assistant"
| |
|
def run_inference(tok, model, system: str, user: str) -> str:
    """Greedily generate a response for one (system, user) pair.

    The prompt comes from ``build_prompt`` and is tokenized with
    ``truncation=True``, so it is cut to the tokenizer's default maximum
    length. Up to MAX_NEW_TOKS new tokens are decoded with ``do_sample=False``.

    Returns:
        Only the newly generated text, with special tokens removed and
        surrounding whitespace stripped.
    """
    prompt = build_prompt(system, user)
    inputs = tok(prompt, return_tensors="pt", truncation=True).to(model.device)
    # Remember the prompt length so only the continuation is decoded.
    input_len = inputs["input_ids"].shape[-1]

    # Stop on the chat end-of-turn marker, and also on the tokenizer's EOS in
    # case the model emits that instead. Otherwise decoding would run on to
    # MAX_NEW_TOKS and append unrelated text to the translation.
    stop_ids = [tok.convert_tokens_to_ids("<end_of_turn>")]
    if tok.eos_token_id is not None and tok.eos_token_id not in stop_ids:
        stop_ids.append(tok.eos_token_id)

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKS,
            do_sample=False,
            eos_token_id=stop_ids,
            pad_token_id=tok.eos_token_id,
        )
    gen_ids = out[0][input_len:]
    return tok.decode(gen_ids, skip_special_tokens=True).strip()
| |
|
def load_test_examples(path: str) -> list[tuple[str, str, str]]:
    """Load (system, user, assistant) triples from a UTF-8 JSONL file.

    Each non-blank line must be a JSON object with "system", "user" and
    "assistant" string fields. Each field is whitespace-stripped. Blank or
    whitespace-only lines are skipped.

    Returns:
        List of 3-tuples of str, in file order.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a record is missing one of the three fields.
    """
    examples = []
    # `with` closes the file deterministically instead of leaving it to GC.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            examples.append((
                obj["system"].strip(),
                obj["user"].strip(),
                obj["assistant"].strip(),
            ))
    return examples
| |
|
def evaluate_bleu_nltk(hyps, refs):
    """
    Compute corpus-level 4-gram BLEU using NLTK.

    - hyps: list of hypothesis strings
    - refs: list of reference strings, one reference per hypothesis

    Both sides are tokenized by plain whitespace splitting, so the score
    depends on that tokenization and is not directly comparable to scores
    computed with a different tokenizer. Uniform 4-gram weights are used
    with NLTK's smoothing method1, which avoids a zero score when some
    n-gram order has no matches.

    Returns BLEU as a fraction in [0, 1], rounded to 4 decimals
    (e.g. 0.2753, i.e. 27.53 BLEU). It is NOT a percentage.
    """
    # One whitespace-tokenized hypothesis per segment.
    tokenized_hyps = [hyp.split() for hyp in hyps]
    # corpus_bleu expects a list of references per segment; there is exactly one.
    tokenized_refs = [[ref.split()] for ref in refs]

    score = corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method1,
    )
    # Keep the raw [0, 1] scale. Callers serialize this value as-is.
    return round(score, 4)
| |
|
def main():
    """Translate every test example, score it with BLEU, and write a JSON report.

    Reads TEST_FILE and strips the leading <src=..><tgt=..> tags from each
    user field. It generates a hypothesis per example and computes corpus
    BLEU against the reference ("assistant") fields. It then writes the model
    path, the BLEU score and per-example records to OUTPUT_JSON.
    """
    tok, model = load_model_and_tokenizer()
    examples = load_test_examples(TEST_FILE)

    # Build the per-example records directly. They also supply the BLEU inputs.
    records = []
    for system, user, reference in tqdm(examples, desc="Translating"):
        clean_user = clean_user_field(user)
        hypothesis = run_inference(tok, model, system, clean_user)
        records.append({
            "system": system,
            "user": clean_user,
            "reference": reference,
            "hypothesis": hypothesis,
        })

    bleu_score = evaluate_bleu_nltk(
        [rec["hypothesis"] for rec in records],
        [rec["reference"] for rec in records],
    )

    out = {
        "model": MODEL_PATH,
        "bleu": bleu_score,
        "examples": records,
    }
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps Kazakh text readable in the output file.
        json.dump(out, f, ensure_ascii=False, indent=2)
    print(f"✅ Saved cleaned evaluation to {OUTPUT_JSON}")


if __name__ == "__main__":
    main()
| |
|