NLP-RAG / retriever / processor.py
Qar-Raz's picture
hf-space: deploy branch without frontend/data/results
c7256ee
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
CharacterTextSplitter,
SentenceTransformersTokenTextSplitter,
NLTKTextSplitter
)
from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import nltk
nltk.download('punkt_tab', quiet=True)
import pandas as pd
import re
class MarkdownTextSplitter:
    """
    Header-driven markdown chunking strategy.

    Walks the markdown header hierarchy (h1 -> h2 -> h3 -> h4): a section
    that fits within ``max_chars`` is kept whole, while an oversized section
    is re-split on the next-deeper header level. Adjacent small sections are
    packed together as long as their combined length stays within the limit.
    """
    def __init__(self, max_chars: int = 4000):
        self.max_chars = max_chars
        self.headers = ["\n# ", "\n## ", "\n### ", "\n#### "]

    def split_text(self, text: str) -> List[str]:
        """Return the chunks produced by hierarchical header splitting."""
        return self._split_by_header(text, 0)

    def _split_by_header(self, content: str, header_level: int) -> List[str]:
        """
        Split ``content`` on the header at ``header_level``, recursing into
        deeper levels whenever a section is still too large.

        Args:
            content: The text content to split
            header_level: Current header level (0=h1, 1=h2, etc.)
        Returns:
            List of text chunks
        """
        # Already small enough, or no deeper header level left to try:
        # return the content unchanged as a single chunk.
        if len(content) <= self.max_chars or header_level >= len(self.headers):
            return [content]

        marker = self.headers[header_level]
        # Zero-width lookahead split keeps each header attached to its section.
        sections = re.split(f'(?={re.escape(marker)})', content)

        # No header at this level anywhere in the text: try the next level down.
        if len(sections) == 1:
            return self._split_by_header(content, header_level + 1)

        chunks: List[str] = []
        buffer = ""
        for section in sections:
            if len(section) > self.max_chars:
                # Oversized section: flush the buffer, then recurse into it
                # with the next-deeper header level.
                if buffer:
                    chunks.append(buffer)
                    buffer = ""
                chunks.extend(self._split_by_header(section, header_level + 1))
            elif buffer and len(buffer) + len(section) > self.max_chars:
                # Packing this section would overflow: emit the buffer and
                # start a fresh one with the current section.
                chunks.append(buffer)
                buffer = section
            else:
                # Section fits alongside what we have accumulated so far.
                buffer += section
        # Emit whatever is left in the buffer.
        if buffer:
            chunks.append(buffer)
        return chunks
class ChunkProcessor:
    """
    Converts documents into embedded, vector-store-ready chunks.

    Owns a SentenceTransformer encoder plus a factory of text-splitting
    strategies (fixed / recursive / character / paragraph / sentence /
    semantic / page / markdown).
    """
    def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True, load_hf_embeddings: bool = False):
        """
        Args:
            model_name: SentenceTransformer / Hugging Face model id.
            verbose: Default verbosity used by `process`.
            load_hf_embeddings: Build the LangChain embeddings wrapper eagerly.
                When False it is created lazily, only if the "semantic"
                splitter actually needs it.
        """
        self.model_name = model_name
        # Some repos (currently jinaai/*) ship custom modeling code and
        # require an explicit trust_remote_code opt-in.
        self._use_remote_code = self._requires_remote_code(model_name)
        st_kwargs = {"trust_remote_code": True} if self._use_remote_code else {}
        self.encoder = SentenceTransformer(model_name, **st_kwargs)
        self.verbose = verbose
        self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name, **self._hf_kwargs()) if load_hf_embeddings else None

    def _hf_kwargs(self) -> Dict[str, Any]:
        """Extra kwargs for HuggingFaceEmbeddings (remote-code opt-in when required)."""
        return {"model_kwargs": {"trust_remote_code": True}} if self._use_remote_code else {}

    def _requires_remote_code(self, model_name: str) -> bool:
        """Return True when the model id needs `trust_remote_code` (jinaai/* repos).

        Tolerates None and surrounding whitespace; matching is case-insensitive.
        """
        normalized = (model_name or "").strip().lower()
        return normalized.startswith("jinaai/")

    def _get_hf_embeddings(self):
        """Lazily create and cache the LangChain embeddings wrapper."""
        if self.hf_embeddings is None:
            self.hf_embeddings = HuggingFaceEmbeddings(model_name=self.model_name, **self._hf_kwargs())
        return self.hf_embeddings

    # ------------------------------------------------------------------
    # Splitters
    # ------------------------------------------------------------------
    def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
        """
        Factory method to return different chunking strategies.

        Strategies:
        - "fixed": Character-based, may split mid-sentence
        - "recursive": Recursive character splitting with hierarchical separators
        - "character": Character-based splitting on paragraph boundaries
        - "paragraph": Paragraph-level splitting on \\n\\n boundaries
        - "sentence": Sliding window over NLTK sentences
        - "semantic": Embedding-based semantic chunking
        - "page": Page-level splitting on page markers
        - "markdown": Hierarchical markdown-header splitting

        Args:
            technique: Name of the chunking strategy (see above).
            chunk_size: Maximum chunk size in characters.
            chunk_overlap: Overlap between consecutive chunks in characters.
            **kwargs: Strategy-specific overrides (e.g. `separator`,
                `separators`, `breakpoint_threshold_type`).
        Raises:
            ValueError: If `technique` is not a supported strategy name.
        """
        if technique == "fixed":
            return CharacterTextSplitter(
                separator=kwargs.get('separator', ""),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )
        elif technique == "recursive":
            return RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separators=kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]),
                length_function=len,
                keep_separator=kwargs.get('keep_separator', True)
            )
        elif technique == "character":
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "\n\n"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )
        elif technique == "paragraph":
            # Paragraph-level chunking using paragraph breaks (same splitter
            # as "character"; kept as a separate name for caller clarity).
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "\n\n"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )
        elif technique == "sentence":
            # Sentence-level chunking using NLTK (punkt tokenizer, downloaded
            # at module import).
            return NLTKTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n"
            )
        elif technique == "semantic":
            return SemanticChunker(
                self._get_hf_embeddings(),
                breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
                # Using 70 because 95 was giving way too big chunks
                breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70)
            )
        elif technique == "page":
            # Page-level chunking using page markers
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "--- Page"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )
        elif technique == "markdown":
            # Markdown header chunking - splits by headers with max char limit
            return MarkdownTextSplitter(max_chars=chunk_size)
        else:
            raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, paragraph, sentence, semantic, page, markdown")

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------
    def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
                chunk_overlap: int = 50, max_docs: Optional[int] = 5,
                verbose: Optional[bool] = None, **kwargs) -> List[Dict[str, Any]]:
        """
        Processes a DataFrame into vector-ready chunks.

        Args:
            df: DataFrame with columns: id, title, url, full_text
            technique: Chunking strategy to use
            chunk_size: Maximum size of each chunk in characters
            chunk_overlap: Overlap between consecutive chunks
            max_docs: Number of documents to process (None for all)
            verbose: Override instance verbose setting
            **kwargs: Additional arguments passed to the splitter
        Returns:
            List of chunk dicts with embeddings and metadata
        Raises:
            ValueError: If required columns are missing or the technique
                is unsupported.
        """
        should_print = verbose if verbose is not None else self.verbose
        required_cols = ['id', 'title', 'url', 'full_text']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"DataFrame missing required columns: {missing_cols}")
        splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
        # Explicit None check: max_docs=0 must mean "zero documents", not "all"
        # (a bare truthiness test would treat 0 as "process everything").
        subset_df = df.head(max_docs) if max_docs is not None else df
        processed_chunks = []
        for _, row in subset_df.iterrows():
            if should_print:
                self._print_document_header(row['title'], row['url'], technique, chunk_size, chunk_overlap)
            raw_chunks = splitter.split_text(row['full_text'])
            for i, text in enumerate(raw_chunks):
                # Some splitters return Document objects, others plain strings.
                content = text.page_content if hasattr(text, 'page_content') else text
                if should_print:
                    self._print_chunk(i, content)
                processed_chunks.append({
                    "id": f"{row['id']}-chunk-{i}",
                    "values": self.encoder.encode(content).tolist(),
                    "metadata": {
                        "title": row['title'],
                        "text": content,
                        "url": row['url'],
                        "chunk_index": i,
                        "technique": technique,
                        "chunk_size": len(content),
                        "total_chunks": len(raw_chunks)
                    }
                })
            if should_print:
                self._print_document_summary(len(raw_chunks))
        if should_print:
            self._print_processing_summary(len(subset_df), processed_chunks)
        return processed_chunks

    # ------------------------------------------------------------------
    # Printing
    # ------------------------------------------------------------------
    def _print_document_header(self, title, url, technique, chunk_size, chunk_overlap):
        """Print a banner describing the document about to be chunked."""
        print("\n" + "="*80)
        print(f"DOCUMENT: {title}")
        print(f"URL: {url}")
        print(f"Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
        print("-" * 80)

    def _print_chunk(self, index, content):
        """Print one chunk with its index and character count."""
        print(f"\n[Chunk {index}] ({len(content)} chars):")
        print(f"  {content}")

    def _print_document_summary(self, num_chunks):
        """Print the per-document chunk count footer."""
        print(f"Total Chunks Generated: {num_chunks}")
        print("="*80)

    def _print_processing_summary(self, num_docs, processed_chunks):
        """Print overall totals and the average chunk size across all chunks."""
        print(f"\nFinished processing {num_docs} documents into {len(processed_chunks)} chunks.")
        if processed_chunks:
            avg = sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks)
            print(f"Average chunk size: {avg:.0f} chars")