| import numpy as np |
| import onnxruntime as ort |
| from omegaconf import OmegaConf |
| from sentencepiece import SentencePieceProcessor |
| from typing import List |
|
|
def process_text(input_text: str) -> str:
    """Restore punctuation, capitalization, and sentence breaks in raw text.

    Encodes ``input_text`` with a SentencePiece model, runs an ONNX model that
    predicts per-token pre-punctuation, post-punctuation, per-character
    capitalization, and sentence boundaries, then rebuilds the surface string
    from those predictions.

    Args:
        input_text: Raw text, typically lowercase and unpunctuated.

    Returns:
        The restored text, one predicted sentence per line.
    """
    # NOTE(review): the tokenizer, ONNX session, and config are reloaded on
    # every call; hoist/cache them at module level if this is called often.
    spe_path = "sp.model"
    tokenizer: SentencePieceProcessor = SentencePieceProcessor(spe_path)

    onnx_path = "model.onnx"
    ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)

    config_path = "config.yaml"
    config = OmegaConf.load(config_path)
    pre_labels: List[str] = config.pre_labels    # labels inserted before a token
    post_labels: List[str] = config.post_labels  # labels appended after a token
    null_token = config.get("null_token", "<NULL>")
    acronym_token = config.get("acronym_token", "<ACRONYM>")
    # NOTE(review): max_len and languages are read but never used below.
    # Presumably the input should be truncated to max_len subwords before
    # inference — confirm against the model card before applying.
    max_len = config.max_length
    languages: List[str] = config.languages

    # Surround the subword ids with BOS/EOS as the model expects.
    input_ids = [tokenizer.bos_id()] + tokenizer.EncodeAsIds(input_text) + [tokenizer.eos_id()]

    # Add a batch dimension of 1. Fix: `np.array` is a constructor, not a
    # type — the correct annotation is `np.ndarray`.
    input_ids_arr: np.ndarray = np.array([input_ids])

    pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(
        None, {"input_ids": input_ids_arr}
    )
    # Drop the batch dimension from each output.
    pre_preds = pre_preds[0].tolist()
    post_preds = post_preds[0].tolist()
    cap_preds = cap_preds[0].tolist()
    sbd_preds = sbd_preds[0].tolist()

    output_texts: List[str] = []   # completed sentences
    current_chars: List[str] = []  # characters of the sentence being built

    # Skip BOS (index 0) and EOS (last index); iterate the real subwords.
    for token_idx in range(1, len(input_ids) - 1):
        token = tokenizer.IdToPiece(input_ids[token_idx])
        # "▁" marks the start of a new word in SentencePiece output;
        # render it as a space (unless we are at the start of a sentence).
        if token.startswith("▁") and current_chars:
            current_chars.append(" ")

        pre_label = pre_labels[pre_preds[token_idx]]
        post_label = post_labels[post_preds[token_idx]]

        # Punctuation predicted to precede this token (e.g. "¿").
        if pre_label != null_token:
            current_chars.append(pre_label)

        # Emit the token's characters (minus the "▁" marker), applying the
        # per-character capitalization predictions. cap_preds is indexed by
        # the character's position within the full piece, including "▁".
        char_start = 1 if token.startswith("▁") else 0
        for token_char_idx, char in enumerate(token[char_start:], start=char_start):
            if cap_preds[token_idx][token_char_idx]:
                char = char.upper()
            current_chars.append(char)
            # Acronym: a period after every character (e.g. "usa" -> "U.S.A.").
            if post_label == acronym_token:
                current_chars.append(".")

        # Ordinary trailing punctuation for this token.
        if post_label != null_token and post_label != acronym_token:
            current_chars.append(post_label)

        # Sentence boundary predicted after this token: flush the sentence.
        if sbd_preds[token_idx]:
            output_texts.append("".join(current_chars))
            current_chars.clear()

    # Flush any trailing text. Fix: guard against appending an empty string
    # (a spurious blank line) when the final token already closed a sentence.
    if current_chars:
        output_texts.append("".join(current_chars))

    return "\n".join(output_texts)
|
|
| |
if __name__ == "__main__":
    # Guard the demo so importing this module does not trigger model loading
    # and inference as a side effect.
    # Sample input: a Kyrgyz greeting ("hello, how are you").
    input_text = "салам кандайсың"
    processed_text = process_text(input_text)
    print("Обработанный текст:")  # "Processed text:"
    print(processed_text)
|
|