Spaces:
Running
Running
| """ | |
| Model architectures and configurations for summarization | |
| """ | |
| import logging | |
| from typing import Optional, List | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| AutoModelForCausalLM, | |
| pipeline | |
| ) | |
| logger = logging.getLogger(__name__) | |
| class SummarizationModelLoader: | |
| """Load and manage pre-trained summarization models.""" | |
| # Popular pre-trained models for summarization | |
| AVAILABLE_MODELS = { | |
| # Fast & Lightweight models (RECOMMENDED) | |
| 't5-small': 'google-t5/t5-small', | |
| 't5-base': 'google-t5/t5-base', | |
| 't5-large': 'google-t5/t5-large', | |
| 'bart-base': 'facebook/bart-base', | |
| 'bart-large-cnn': 'facebook/bart-large-cnn', | |
| 'pegasus-arxiv': 'google/pegasus-arxiv', | |
| 'pegasus-pubmed': 'google/pegasus-pubmed', | |
| 'led': 'allenai/led-base-16384', # For long documents | |
| # Multilingual models (supports 50+ languages) | |
| 'mbart-50': 'facebook/mbart-large-50', | |
| 'mbart-50-small': 'facebook/mbart-large-50-small', # FASTEST - recommended for speed | |
| 'mt5-small': 'google/mt5-small', # Multilingual T5 (100+ languages) | |
| 'mt5-base': 'google/mt5-base', | |
| } | |
| # Supported languages for multilingual models | |
| SUPPORTED_LANGUAGES = { | |
| 'english': 'en_XX', | |
| 'spanish': 'es_XX', | |
| 'french': 'fr_XX', | |
| 'german': 'de_DE', | |
| 'italian': 'it_IT', | |
| 'portuguese': 'pt_XX', | |
| 'chinese': 'zh_CN', | |
| 'japanese': 'ja_XX', | |
| 'korean': 'ko_KR', | |
| 'arabic': 'ar_AR', | |
| 'hindi': 'hi_IN', | |
| 'russian': 'ru_RU', | |
| 'turkish': 'tr_TR', | |
| 'vietnamese': 'vi_VN', | |
| 'thai': 'th_TH', | |
| } | |
| def __init__(self, model_name: str = 't5-small', device: Optional[str] = None, language: str = 'english'): | |
| """ | |
| Initialize model loader. | |
| Args: | |
| model_name: Name of the model to load (default: t5-small for speed) | |
| device: Device to load model on ('cpu' or 'cuda') | |
| language: Language for multilingual models | |
| """ | |
| self.model_name = model_name | |
| self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu') | |
| self.language = language | |
| self.language_code = self.SUPPORTED_LANGUAGES.get(language.lower(), 'en_XX') | |
| self.model = None | |
| self.tokenizer = None | |
| logger.info(f"Using device: {self.device}") | |
| logger.info(f"Language: {language} (code: {self.language_code})") | |
| def load_model(self, model_path: Optional[str] = None) -> tuple: | |
| """ | |
| Load model and tokenizer. | |
| Args: | |
| model_path: Path to local model or HuggingFace model ID | |
| Returns: | |
| Tuple of (model, tokenizer) | |
| """ | |
| path = model_path or self.AVAILABLE_MODELS.get(self.model_name, self.model_name) | |
| try: | |
| logger.info(f"Loading tokenizer from {path}") | |
| self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True) # Fast tokenizer | |
| # Set language for multilingual models | |
| if 'mbart' in path.lower() or 'mt5' in path.lower(): | |
| self.tokenizer.src_lang = self.language_code | |
| logger.info(f"Loading model from {path}") | |
| self.model = AutoModelForSeq2SeqLM.from_pretrained(path) | |
| self.model.to(self.device) | |
| self.model.eval() | |
| logger.info(f"Model loaded successfully on {self.device}") | |
| return self.model, self.tokenizer | |
| except Exception as e: | |
| logger.error(f"Error loading model: {str(e)}") | |
| raise | |
| def get_model_info(self) -> dict: | |
| """Get information about loaded model.""" | |
| if self.model is None: | |
| return {"status": "Model not loaded"} | |
| return { | |
| "model_name": self.model_name, | |
| "device": self.device, | |
| "parameters": sum(p.numel() for p in self.model.parameters()), | |
| "trainable_parameters": sum(p.numel() for p in self.model.parameters() if p.requires_grad) | |
| } | |
| class IntentClassifier: | |
| """Classify user intent for summarization.""" | |
| INTENT_TYPES = { | |
| 'technical_overview': 'high-level technical summary', | |
| 'detailed_analysis': 'comprehensive technical analysis', | |
| 'methodology': 'focus on methods and approach', | |
| 'results': 'focus on results and findings', | |
| 'conclusion': 'focus on conclusions and implications', | |
| 'abstract': 'paper abstract-like summary', | |
| } | |
| def __init__(self): | |
| """Initialize intent classifier.""" | |
| self.intent_prompts = {} | |
| for intent_key, intent_desc in self.INTENT_TYPES.items(): | |
| self.intent_prompts[intent_key] = f"Provide {intent_desc}" | |
| def classify_intent(self, user_input: str) -> str: | |
| """ | |
| Classify user intent from input text. | |
| Args: | |
| user_input: User's intent description | |
| Returns: | |
| Classified intent type | |
| """ | |
| user_input_lower = user_input.lower() | |
| # Simple keyword matching (can be enhanced with ML model) | |
| for intent_key in self.INTENT_TYPES.keys(): | |
| if intent_key.replace('_', ' ') in user_input_lower: | |
| return intent_key | |
| # Default to technical_overview | |
| return 'technical_overview' | |
| def get_prompt_for_intent(self, intent: str) -> str: | |
| """Get customized prompt for specific intent.""" | |
| return self.intent_prompts.get(intent, self.intent_prompts['technical_overview']) | |
| class ContextPreserver: | |
| """Preserve important context during summarization.""" | |
| def __init__(self): | |
| """Initialize context preserver.""" | |
| self.important_patterns = { | |
| 'method': r'(?:method|approach|technique|algorithm)', | |
| 'metric': r'(?:metric|accuracy|precision|recall|f1|score)', | |
| 'dataset': r'(?:dataset|corpus|benchmark)', | |
| 'baseline': r'(?:baseline|state-of-the-art|sota)', | |
| } | |
| def extract_important_content(self, text: str) -> List[str]: | |
| """ | |
| Extract important content that should be preserved. | |
| Args: | |
| text: Document text | |
| Returns: | |
| List of important content snippets | |
| """ | |
| important_snippets = [] | |
| sentences = text.split('.') | |
| for sentence in sentences: | |
| for pattern in self.important_patterns.values(): | |
| if len(sentence) > 20: # Skip very short sentences | |
| important_snippets.append(sentence.strip()) | |
| break | |
| return important_snippets | |
| def weight_content(self, sentences: List[str]) -> List[float]: | |
| """ | |
| Assign importance weights to sentences. | |
| Args: | |
| sentences: List of sentences | |
| Returns: | |
| List of importance weights | |
| """ | |
| weights = [] | |
| for sentence in sentences: | |
| weight = 1.0 | |
| # Boost weight for important content | |
| for pattern in self.important_patterns.values(): | |
| if any(word in sentence.lower() for word in pattern.split('|')): | |
| weight += 0.5 | |
| weights.append(min(weight, 2.0)) # Cap at 2.0 | |
| return weights | |