from typing import Literal

from transformers import AutoTokenizer
from langchain_text_splitters import RecursiveCharacterTextSplitter, NLTKTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface.embeddings import HuggingFaceEmbeddings


class Splitter:
    """
    Splits text into chunks using one of three selectable strategies:

    - recursively splitting on a list of separators ordered by the
      increasing "harshness" of their effect;

    - merging sentences extracted with the NLTK library into chunks
      of a fixed token size, with overlap;

    - splitting the text into semantically coherent blocks using
      text embeddings.
    """

    def __init__(
        self,
        mode: Literal["recursive", "nltk", "semantic"],
        model_name: str = "deepvk/USER-bge-m3",
        chunk_size: int = 256,
        chunk_overlap: int = 64,
        **splitter_kwargs,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # The tokenizer is used to measure lengths in model tokens
        # rather than characters.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        match mode:
| | case "recursive": |
| | self.splitter = RecursiveCharacterTextSplitter( |
| | separators=[ |
| | "\n### ", "\n## ", "\n# ", |
| | "\n\n", "\n", |
| | "!", "?", ". ", ";", ",", ")", " ", "", |
| | ], |
| | keep_separator="end", |
| | chunk_size=chunk_size, |
| | chunk_overlap=chunk_overlap, |
| | length_function=lambda x: len(self.tokenizer.encode(x, add_special_tokens=False)), |
| | **splitter_kwargs, |
| | ) |
| | self.split_fn = self._recursive_split |
| | |
| | case "nltk": |
| | self.splitter = NLTKTextSplitter( |
| | language="russian", |
| | **splitter_kwargs, |
| | ) |
| | self.split_fn = self._nltk_split |
| | |
| | case "semantic": |
| | self.splitter = SemanticChunker( |
| | HuggingFaceEmbeddings( |
| | model_name=model_name, |
| | encode_kwargs={"normalize_embeddings": True}, |
| | ), |
| | **splitter_kwargs, |
| | ) |
| | self.split_fn = self._semantic_split |
| |
|
| |
|
    def split_text(self, text: str) -> list[str]:
        """
        Public entry point: splits the text into chunks using the
        strategy selected at construction time.
        """
        return self.split_fn(text)

    def _recursive_split(self, text: str) -> list[str]:
        """
        Splits the text into chunks when self.splitter is a RecursiveCharacterTextSplitter.
        """
        # Drop chunks that contain no letters at all
        # (separator and punctuation debris).
        return [
            chunk
            for chunk in self.splitter.split_text(text)
            if any(ch.isalpha() for ch in chunk)
        ]

    def _nltk_split(self, text: str) -> list[str]:
        """
        Splits the text into chunks when self.splitter is an NLTKTextSplitter:
        sentences are packed greedily into chunks of at most self.chunk_size
        tokens, and consecutive chunks overlap by up to self.chunk_overlap tokens.
        """
        # NLTKTextSplitter joins sentences with "\n\n". Flatten every returned
        # chunk instead of taking only the first one, so that texts longer than
        # the splitter's own chunk_size are not silently truncated.
        sentences = [
            sent
            for chunk in self.splitter.split_text(text)
            for sent in chunk.split("\n\n")
        ]
        sent_sizes = [
            len(self.tokenizer.encode(sent, add_special_tokens=False))
            for sent in sentences
        ]

        chunks = []
        i, n = 0, len(sentences)
        while i < n:
            cur_len, cur_texts = 0, []

            # Greedily take whole sentences while the token budget allows.
            j = i
            while (j < n) and (cur_len + sent_sizes[j] <= self.chunk_size):
                cur_texts.append(sentences[j])
                cur_len += sent_sizes[j]
                j += 1

            # A single sentence longer than chunk_size would never fit;
            # emit it as its own chunk so the loop always makes progress.
            if j == i:
                cur_texts.append(sentences[j])
                j += 1

            chunks.append(" ".join(cur_texts))

            if j >= n:
                break

            # Walk backwards from the end of the chunk to find how many
            # trailing sentences fit into the overlap budget.
            overlap_len, k = 0, j - 1
            while (k >= i) and (overlap_len + sent_sizes[k] <= self.chunk_overlap):
                overlap_len += sent_sizes[k]
                k -= 1

            # Start the next chunk at the beginning of the overlap, but
            # always advance by at least one sentence.
            i = max(k + 1, i + 1)

        return chunks

    def _semantic_split(self, text: str) -> list[str]:
        """
        Splits the text into chunks when self.splitter is a SemanticChunker.
        """
        return self.splitter.split_text(text)
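

# A minimal usage sketch, not part of the class itself. It assumes the
# "deepvk/USER-bge-m3" model can be downloaded from the Hugging Face Hub
# and, for the "nltk" mode, that the NLTK "punkt" sentence-tokenizer data
# has been installed (e.g. via nltk.download("punkt")).
if __name__ == "__main__":
    sample = (
        "# Title\n\n"
        "First sentence of the document. Second sentence.\n\n"
        "A new paragraph with one more sentence."
    )

    # "nltk" or "semantic" could be passed here instead of "recursive".
    splitter = Splitter(mode="recursive", chunk_size=64, chunk_overlap=16)
    for chunk in splitter.split_text(sample):
        print(repr(chunk))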