Summarization_Deploy / summarizer.py
USF00's picture
Optimize summarization generation and PDF OCR processing
81a53ea
import math
import logging
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from utils import iter_paragraphs, split_sentences, normalize_text
logger = logging.getLogger(__name__)
# Model config
MODEL_NAME = "facebook/bart-large-cnn"
BATCH_SIZE = 4
NUM_BEAMS = 4
NO_REPEAT_NGRAM_SIZE = 3
EARLY_STOPPING = True
# Chunking config
MAX_INPUT_TOKENS = 1024
HEADROOM_TOKENS = 16
EFFECTIVE_MAX_INPUT = MAX_INPUT_TOKENS - HEADROOM_TOKENS
OVERLAP_SENTENCES = 2
# Output size caps
CHAPTER_MAX_NEW_TOKENS_CAP = 320
CHAPTER_MIN_NEW_TOKENS_FLOOR = 120
BOOK_PARTS = 8
class BookSummarizer:
def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = None
self.model = None
def load_model(self):
"""Loads the tokenizer and model into memory."""
if self.model is not None:
return
logger.info(f"Loading model {MODEL_NAME} onto {self.device}...")
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(self.device)
if self.device == "cuda":
try:
self.model.half()
except Exception as e:
logger.warning(f"Could not convert model to fp16: {e}")
self.model.eval()
logger.info("Model loaded successfully.")
def tok_len(self, s: str) -> int:
if not self.tokenizer:
self.load_model()
return len(self.tokenizer.encode(s, add_special_tokens=False))
def split_by_tokens(self, s: str, max_len: int, overlap_tokens: int = 64):
if not self.tokenizer:
self.load_model()
ids = self.tokenizer.encode(s, add_special_tokens=False)
if len(ids) <= max_len:
return [s.strip()]
overlap_tokens = max(0, min(overlap_tokens, max_len // 3))
step = max(1, max_len - overlap_tokens)
parts = []
for i in range(0, len(ids), step):
chunk_ids = ids[i:i+max_len]
if not chunk_ids:
continue
t = self.tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True).strip()
if t:
parts.append(t)
return parts
def chunk_text(self, text: str, max_input_tokens: int = EFFECTIVE_MAX_INPUT, overlap_sentences: int = OVERLAP_SENTENCES):
text = normalize_text(text)
if not text:
return []
chunks = []
cur_sents, cur_tok = [], 0
def flush():
nonlocal cur_sents, cur_tok
if cur_sents:
ch = " ".join(cur_sents).strip()
if ch:
chunks.append(ch)
cur_sents, cur_tok = [], 0
for para in iter_paragraphs(text):
for sent in split_sentences(para):
st = sent.strip()
if not st:
continue
st_tok = self.tok_len(st)
if st_tok > max_input_tokens:
flush()
chunks.extend(self.split_by_tokens(st, max_len=max_input_tokens, overlap_tokens=64))
continue
if cur_tok + st_tok <= max_input_tokens:
cur_sents.append(st)
cur_tok += st_tok
else:
prev = cur_sents[:]
flush()
overlap = prev[-overlap_sentences:] if overlap_sentences and prev else []
cur_sents = overlap + [st]
cur_tok = self.tok_len(" ".join(cur_sents))
flush()
return chunks
@torch.no_grad()
def generate_summaries(self, texts, min_new_tokens, max_new_tokens, batch_size=BATCH_SIZE):
if not self.model:
self.load_model()
outs = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
enc = self.tokenizer(
batch, return_tensors="pt",
truncation=True, padding=True,
max_length=EFFECTIVE_MAX_INPUT
).to(self.device)
try:
gen = self.model.generate(
**enc,
num_beams=NUM_BEAMS,
no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
min_new_tokens=min_new_tokens,
max_new_tokens=max_new_tokens,
early_stopping=EARLY_STOPPING,
)
except TypeError:
gen = self.model.generate(
**enc,
num_beams=NUM_BEAMS,
no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
min_length=min_new_tokens,
max_length=max_new_tokens,
early_stopping=EARLY_STOPPING,
)
decoded = self.tokenizer.batch_decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=True)
outs.extend([d.strip() for d in decoded])
return outs
def summarize_long_text(self, text: str, min_new: int, max_new: int):
chunks = self.chunk_text(text)
if not chunks:
return ""
chunk_summaries = []
for ch in chunks:
tlen = self.tok_len(ch)
dyn_max = int(min(max_new, max(min_new, round(tlen * 0.18))))
dyn_min = max(30, min(min_new, dyn_max - 10))
chunk_summaries.append(self.generate_summaries([ch], dyn_min, dyn_max, batch_size=1)[0])
if len(chunk_summaries) == 1:
return chunk_summaries[0]
current = chunk_summaries
for _ in range(6):
combined = "\n".join([f"Part {i+1}: {t}" for i, t in enumerate(current)])
if self.tok_len(combined) <= EFFECTIVE_MAX_INPUT:
return self.generate_summaries([combined], min_new, max_new, batch_size=1)[0]
sub_chunks = self.chunk_text(combined, overlap_sentences=1)
current = self.generate_summaries(
sub_chunks,
min_new_tokens=max(60, min_new // 2),
max_new_tokens=max(180, max_new // 2),
batch_size=BATCH_SIZE
)
return "\n".join(current).strip()
def make_big_book_summary(self, chapter_summaries, parts=BOOK_PARTS):
chap_summaries = [s for s in chapter_summaries if s.strip()]
if not chap_summaries:
return ""
n = len(chap_summaries)
group_size = max(1, math.ceil(n / parts))
groups = [chap_summaries[i:i+group_size] for i in range(0, n, group_size)]
part_summaries = []
for gi, g in enumerate(groups):
combined = "\n".join([f"ChapterSummary {gi+1}.{i+1}: {t}" for i, t in enumerate(g)])
ps = self.summarize_long_text(combined, min_new=220, max_new=520)
part_summaries.append(ps.strip())
return "\n\n".join(part_summaries)
summarizer = BookSummarizer()