| """ |
| @author:cb |
| @contact:chenbo@bat100.net |
| @time:2023/5/30 14:21 |
| @filename:tokenization.py |
| @software:PyCharm |
| @description: |
| """ |
import re
from transformers import FSMTTokenizer as fsmt


class FSMTTokenizer(fsmt):
    # Whitespace immediately preceding a non-alphanumeric, non-space character
    # (CJK text, Chinese punctuation, ...); used to strip such spaces below.
    space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')
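    # e.g. space_re.sub('', 'hello ， 你 好') == 'hello，你好'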

    def moses_tokenize(self, text, lang):
        # Same per-language caching scheme as the parent class, but tokenize
        # with escape=False so sacremoses leaves characters such as '&' and
        # '<' unescaped instead of turning them into XML entities.
        if lang not in self.cache_moses_tokenizer:
            moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
            self.cache_moses_tokenizer[lang] = moses_tokenizer
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        # Source side: English, tagged with prefix id 64812 from this
        # checkpoint's vocabulary.
        self.lang_prefix, self.lang_prefix_id = 'en', 64812

    def _switch_to_target_mode(self):
        # Target side: Chinese, tagged with prefix id 64870 from this
        # checkpoint's vocabulary.
        self.lang_prefix, self.lang_prefix_id = 'zh', 64870
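
    # These two hooks are what the transformers base class toggles around
    # source/target encoding (e.g. `with tokenizer.as_target_tokenizer():`);
    # prepare_for_tokenization below also drives them via its `src` kwarg.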

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequences by prepending
        the current language-prefix token and appending the separator token:

        - single sequence: `<lang_prefix> X </s>`
        - pair of sequences: `<lang_prefix> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        token_ids_0 = [self.lang_prefix_id] + token_ids_0

        if token_ids_1 is None:
            return token_ids_0 + sep
        return token_ids_0 + sep + token_ids_1 + sep

    def moses_pipeline(self, text, lang):
        # Only punctuation normalization is applied here; the parent pipeline's
        # unicode-punctuation replacement, which rewrites Chinese punctuation
        # in its English/ASCII form, is deliberately skipped.
        text = self.moses_punct_norm(text, lang)
        return text

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """
        Overrides the parent implementation, which anglicizes Chinese
        punctuation (e.g. '，' -> ','); see moses_pipeline above.

        Note that `lang` is ignored: the language actually used is
        `self.lang_prefix`, set by the mode-switch hooks.

        :param text: text to tokenize
        :param lang: unused; kept for signature compatibility
        :param bypass_tokenizer: if True, split on whitespace instead of Moses
        :return: list of BPE tokens
        """
        if self.do_lower_case:
            text = text.lower()
        if bypass_tokenizer:
            text = text.split()
        else:
            text = self.moses_pipeline(text, lang=self.lang_prefix)
            text = self.moses_tokenize(text, lang=self.lang_prefix)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend(list(self.bpe(token).split(" ")))

        return split_tokens

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """
        Select the language-prefix side before tokenization.

        :param text: text to prepare
        :param is_split_into_words: passed through to the parent implementation
        :param kwargs: `src=True` (default) selects the English/source prefix,
            `src=False` the Chinese/target prefix
        :return: (text, kwargs) as returned by the parent implementation
        """
        # pop() rather than get(), so this custom kwarg is not forwarded to the
        # parent class, which would warn about unrecognized keyword arguments.
        if kwargs.pop('src', True):
            self._switch_to_input_mode()
        else:
            self._switch_to_target_mode()
        return super().prepare_for_tokenization(text, is_split_into_words=is_split_into_words, **kwargs)

    def convert_tokens_to_string(self, tokens):
        """
        Remove the spaces around non-English characters from the detokenized
        text; for this application the output reads better that way.

        :param tokens: list of tokens to join
        :return: detokenized string
        """
        text = super().convert_tokens_to_string(tokens)
        return FSMTTokenizer.space_re.sub('', text)


if __name__ == '__main__':
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    # tokenize() expects a single string, not a list of strings.
    r = tokenizer.tokenize('hello hi')
    print(r)
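
    # Rough usage sketch (hedged: assumes './' holds an FSMT checkpoint whose
    # vocabulary matches the hardcoded prefix ids 64812/64870; the comments on
    # expected output are illustrative, not verified values).
    src_ids = tokenizer.encode('hello，你好', src=True)   # id list starts with 64812 ('en')
    print(src_ids)
    tgt_ids = tokenizer.encode('你好，世界', src=False)   # id list starts with 64870 ('zh')
    print(tgt_ids)
    # Chinese punctuation is preserved, and space_re strips the spaces that
    # Moses detokenization would otherwise leave around non-English characters.
    tokens = tokenizer.tokenize('hello，你好')
    print(tokenizer.convert_tokens_to_string(tokens))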