Instructions to use verbit/hebrew_punctuation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use verbit/hebrew_punctuation with Transformers:
```python
# Load model directly
from transformers import BertForPunctuation
model = BertForPunctuation.from_pretrained("verbit/hebrew_punctuation", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import BertTokenizer | |
| from src.models import BertForPunctuation | |
# Punctuation strings indexed by predicted class: add_punctuation_to_text takes the
# argmax over each probability row and appends PUNCTUATION_SIGNS[argmax] to the word
# (index 0 is the empty string, i.e. nothing is appended).
PUNCTUATION_SIGNS = ['', ',', '.', '?']
# Value inserted into every model input segment by gen_model_inputs, placed right
# after the token being predicted.
PAUSE_TOKEN = 0
# Identifier passed to from_pretrained() for both the model and the tokenizer in main().
MODEL_NAME = "verbit/hebrew_punctuation"
def tokenize_text(
    word_list: List[str], pause_list: List[float], tokenizer: BertTokenizer
) -> Tuple[List[int], List[int], List[float]]:
    """
    Converts words into token ids and spreads the per-word pauses over the tokens.

    A single word may be split into several tokens, so the position of the last
    token of every word is recorded to map token-level results back to words.

    Args:
        word_list: list of words
        pause_list: list of pauses after each word in seconds, same length as word_list
        tokenizer: tokenizer providing ``tokenize`` and ``convert_tokens_to_ids``

    Returns:
        original_word_idx: for every word, the index of its last token in ``x``
        x: flat list of token ids for all words
        pause: one entry per token; 0 for the non-final tokens of a word and the
            word's pause for its final token
    """
    assert len(word_list) == len(pause_list), "word_list and pause_list should have the same length"

    token_ids = []
    token_pauses = []
    last_token_positions = []

    for word, word_pause in zip(word_list, pause_list):
        pieces = tokenizer.tokenize(word)
        # a word that yields no tokens is represented by a single id 0
        ids = tokenizer.convert_tokens_to_ids(pieces) if pieces else [0]

        token_ids.extend(ids)
        last_token_positions.append(len(token_ids) - 1)

        # only the last token of the word carries the pause, earlier ones get 0
        token_pauses.extend([0] * (len(ids) - 1))
        token_pauses.append(word_pause)

    return last_token_positions, token_ids, token_pauses
def gen_model_inputs(
    x: List[int],
    pause: List[float],
    forward_context: int,
    backward_context: int,
) -> torch.Tensor:
    """
    Builds one fixed-size context window per token and adds a pause token to it.

    The token sequence is padded with zeros on both sides so that every token has
    ``backward_context`` ids before it and ``forward_context`` ids after it. The
    pause token is placed directly after the token the window is centred on.
    The actual pause values are not used here: every position gets PAUSE_TOKEN.

    Args:
        x: list of indexed words
        pause: list of corresponding pauses (only its length is used)
        forward_context: size of the forward context window
        backward_context: size of the backward context window (without the predicted token)

    Returns:
        A tensor of model inputs for each indexed word in x
    """
    window_size = backward_context + forward_context + 1
    split_at = backward_context + 1  # position right after the predicted token
    pause_tokens = [PAUSE_TOKEN] * len(pause)
    padded_ids = [0] * backward_context + x + [0] * forward_context

    rows = []
    for start in range(len(x)):
        window = padded_ids[start : start + window_size]
        rows.append(window[:split_at] + [pause_tokens[start]] + window[split_at:])

    return torch.tensor(rows)
def add_punctuation_to_text(text: str, punct_prob: np.ndarray) -> str:
    """
    Appends the most probable punctuation sign to every word of the text.

    Args:
        text: text to insert punctuation to; it is split on whitespace into words
        punct_prob: matrix of probabilities, one row per word and one column per
            entry of PUNCTUATION_SIGNS

    Returns:
        the words with their punctuation attached, joined by single spaces
    """
    tokens = text.split()
    best_class = np.argmax(punct_prob, axis=1)
    signs = [PUNCTUATION_SIGNS[k] for k in best_class]
    # the "no punctuation" sign is an empty string, so plain concatenation leaves the word as is
    return ' '.join(token + sign for token, sign in zip(tokens, signs))
def get_prediction(
    model: BertForPunctuation,
    text: str,
    tokenizer: BertTokenizer,
    batch_size: int = 16,
    backward_context: int = 15,
    forward_context: int = 16,
    pause_list: Optional[List[float]] = None,
    device: str = 'cpu',
) -> str:
    """
    Predicts punctuation for a text and returns the text with punctuation inserted.

    The text is split on whitespace, tokenized, turned into one context window per
    token, and only the window of each word's last token is fed to the model. The
    model outputs are converted to probabilities with a softmax over dim 1 and the
    most probable sign is appended to every word.

    Args:
        model: punctuation model; it is called on batches of input windows and is
            not moved to ``device`` here, only the inputs are
        text: text to predict punctuation for
        tokenizer: tokenizer
        batch_size: batch size
        backward_context: size of the backward context window
        forward_context: size of the forward context window
        pause_list: list of pauses after each word in seconds; when missing or
            empty, a pause of 0.0 is used for every word
        device: device the model inputs are moved to

    Returns:
        text with punctuation; an empty string when ``text`` contains no words
    """
    word_list = text.split()
    if not word_list:
        # nothing to punctuate: without this guard no batch is produced and
        # np.concatenate fails on the empty list of outputs
        return ''

    if not pause_list:
        # make default pauses if pauses are not provided
        pause_list = [0.0] * len(word_list)

    word_idx, x, pause = tokenize_text(word_list=word_list, pause_list=pause_list, tokenizer=tokenizer)
    model_inputs = gen_model_inputs(x, pause, forward_context, backward_context)
    # keep only the window centred on the last token of every word
    model_inputs = model_inputs.index_select(0, torch.LongTensor(word_idx)).to(device)
    inputs_length = len(model_inputs)

    output = []
    with torch.no_grad():
        for ndx in range(0, inputs_length, batch_size):
            # slicing clamps at the end of the tensor, so the last batch may be shorter
            logits = model(model_inputs[ndx : ndx + batch_size])
            probs = F.softmax(logits, dim=1)
            output.append(probs.detach().cpu().numpy())

    punct_probabilities_matrix = np.concatenate(output, axis=0)
    return add_punctuation_to_text(text, punct_probabilities_matrix)
def main():
    """Demo entry point: load the pretrained model and tokenizer, punctuate a sample text and print it."""
    punctuation_model = BertForPunctuation.from_pretrained(MODEL_NAME)
    bert_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    punctuation_model.eval()

    # NOTE(review): the sample below looks mis-encoded (a Hebrew model, but the text
    # is rendered as CJK characters) - confirm against the upstream file.
    sample_text = """讞讘专转 讜专讘讬讟 驻讬转讞讛 诪注专讻转 诇转诪诇讜诇 讛诪讘讜住住转 注诇 讘讬谞讛 诪诇讗讻讜转讬转 讜讙讜专诐 讗谞讜砖讬 讜砖讜拽讚转 注诇 转诪诇讜诇 注讚讜讬讜转 谞讬爪讜诇讬 砖讜讗讛
讗转 讛转讜爪讗讜转 讗驻砖专 诇专讗讜转 讻讘专 讘专砖转 讘讛谉 讞诇拽讬诐 诪注讚讜转讜 砖诇 讟讜讘讬讛 讘讬讬诇住拽讬 砖讛讬讛 诪驻拽讚 讙讚讜讚 讛驻专讟讬讝谞讬诐 讛讬讛讜讚讬诐 讘讘讬讬诇讜专讜住讬讛"""

    # context window sizes are read from the model config rather than the function defaults
    model_config = punctuation_model.config
    print(
        get_prediction(
            model=punctuation_model,
            text=sample_text,
            tokenizer=bert_tokenizer,
            backward_context=model_config.backward_context,
            forward_context=model_config.forward_context,
        )
    )


if __name__ == "__main__":
    main()