Spaces:
Running
Running
Pulastya B
Made Major Changes to the Overall User Experience by adding a pipeline visualizer
6b731f7 | """ | |
| Intent Classifier - Determines execution mode for the Reasoning Loop. | |
| Three execution modes: | |
| 1. DIRECT: "Make a scatter plot" → SBERT routing → tool → done | |
| - Clear, specific command with obvious tool mapping | |
| - No reasoning loop needed | |
| 2. INVESTIGATIVE: "Why are customers churning?" → reasoning loop | |
| - Analytical question requiring hypothesis testing | |
| - Reasoning loop drives tool selection | |
| 3. EXPLORATORY: "Analyze this data" → auto-hypothesis → reasoning loop | |
| - Open-ended request with no specific question | |
| - First profiles data, generates hypotheses, then investigates | |
| The classifier uses keyword patterns + semantic features to decide. | |
| This is a lightweight classification (no LLM call needed). | |
| """ | |
| import re | |
| from typing import Optional, Dict, Any, Tuple | |
| from dataclasses import dataclass | |
@dataclass
class IntentResult:
    """Outcome of classifying a user query into an execution mode.

    Declared as a dataclass because the classifier constructs it with keyword
    arguments (``IntentResult(mode=..., confidence=..., reasoning=...,
    sub_intent=...)``); a plain class with bare annotations would reject them.
    """

    mode: str  # "direct", "investigative", or "exploratory"
    confidence: float  # 0.0-1.0
    reasoning: str  # Why this mode was chosen
    # More specific intent (e.g., "visualization", "cleaning"); optional.
    sub_intent: Optional[str] = None
# DIRECT mode: explicit, tool-specific commands.
# The flat list of (regex, sub_intent) pairs is expanded from per-category
# groups; group order and in-group order define the matching priority.
DIRECT_PATTERNS = [
    (regex, category)
    for category, regexes in (
        # Visualization commands
        ("visualization", (
            r"\b(make|create|generate|build|show|draw|plot)\b.*(scatter|histogram|heatmap|box\s*plot|bar\s*chart|pie\s*chart|line\s*chart|dashboard|time\s*series)",
            r"\b(scatter|histogram|heatmap|boxplot|bar\s*chart)\b.*\b(of|for|between|showing)\b",
        )),
        # Data cleaning commands
        ("cleaning", (
            r"\b(clean|remove|drop|fill|impute|handle)\b.*(missing|null|nan|outlier|duplicate)",
            r"\b(fix|convert|change)\b.*(data\s*type|dtype|column\s*type)",
        )),
        # Feature engineering commands
        ("feature_engineering", (
            r"\b(create|add|extract|generate)\b.*(feature|time\s*feature|interaction|encoding)",
            r"\b(encode|one-hot|label\s*encode|ordinal)\b.*\b(categorical|column)",
        )),
        # Model training commands
        ("training", (
            r"\b(train|build|fit|run)\b.*(model|classifier|regressor|baseline|xgboost|random\s*forest)",
            r"\b(tune|optimize)\b.*\b(hyperparameter|model|parameter)",
            r"\b(cross[\s-]?valid)",
        )),
        # Profiling commands
        ("profiling", (
            r"\b(profile|describe|summarize)\b.*\b(dataset|data|table|file)",
            r"\b(data\s*quality|quality\s*check|check\s*quality)",
        )),
        # Report generation
        ("reporting", (
            r"\b(generate|create|build)\b.*\b(report|eda\s*report|profiling\s*report)",
        )),
    )
    for regex in regexes
]
# Patterns that indicate INVESTIGATIVE mode (analytical questions).
# Each entry is a (regex, sub_intent) pair; earlier entries win.
#
# Truncated word stems (differ-, compar-, anomal-, categori-) are followed by
# an explicit suffix matcher. A bare trailing ``\b`` directly after a stem can
# never match the inflected words the stem stands for ("difference",
# "compare", "anomalies", "categorize"), because there is no word boundary in
# the middle of a word.
INVESTIGATIVE_PATTERNS = [
    # Causal / explanatory questions
    (r"\bwhy\b.*(are|is|do|does|did)\b", "causal"),
    (r"\bwhat\b.*(cause|driv|factor|reason|explain|lead)", "causal"),
    (r"\bwhat\b.*(affect|impact|influence|determine)", "causal"),
    # Relationship / correlation questions
    (r"\bhow\b.*(does|do|is|are)\b.*\b(relate|correlat|affect|impact|change|vary)", "relationship"),
    (r"\b(relationship|correlation|association)\b.*\bbetween\b", "relationship"),
    # Comparison questions
    (r"\b(differ\w*|compar\w*|contrast)\b.*\bbetween\b", "comparison"),
    (r"\bwhich\b.*(better|worse|higher|lower|more|less|best|worst)", "comparison"),
    # Pattern / trend questions
    (r"\b(pattern|trend|anomal\w*|outlier|unusual|interesting)\b", "pattern"),
    (r"\bis\s+there\b.*(pattern|trend|relationship|correlation|difference)", "pattern"),
    # Prediction-oriented questions (but NOT direct "train a model" commands)
    (r"\bcan\s+(we|i|you)\b.*(predict|forecast|estimate|determine)", "predictive"),
    (r"\bwhat\b.*(predict|forecast|expect|happen)", "predictive"),
    # Segmentation / grouping questions
    # "categori" is limited to categories / categoriz* / categoris* so that
    # "categorical" (a column-type word, not a grouping request) stays out.
    (r"\b(segment|group|cluster|categori(?:[sz]\w*|es))\b", "segmentation"),
    (r"\bwhat\b.*(type|kind|group|segment)\b.*\b(customer|user|product)", "segmentation"),
]
# EXPLORATORY mode: open-ended requests with no specific question.
# Every regex is anchored at the start of the query. The list is assembled
# from labelled runs so that the original ordering is kept exactly.
EXPLORATORY_PATTERNS = (
    [
        (regex, "general_analysis")
        for regex in (
            r"^analyze\b.*\b(this|the|my)\b.*\b(data|dataset|file|csv)",
            r"^(tell|show)\b.*\b(me|us)\b.*\b(about|everything|what)",
            r"^(explore|investigate|examine|look\s*(at|into))\b.*\b(this|the|my)\b",
            r"^what\b.*\b(can|do)\b.*\b(you|we)\b.*\b(find|learn|discover|see)",
            r"^(give|provide)\b.*\b(overview|summary|insight|analysis)",
        )
    ]
    + [
        (
            r"^(run|do|perform)\b.*\b(full|complete|comprehensive|end.to.end)\b.*\b(analysis|pipeline|workflow)",
            "full_pipeline",
        ),
    ]
    + [
        (r"^(find|discover|uncover)\b.*\b(insight|pattern|trend|interesting)", "general_analysis"),
    ]
)
class IntentClassifier:
    """
    Classify a user query into one of three execution modes.

    Classification is pure regex/heuristic matching (no LLM call). The pattern
    tables are tried in priority order -- DIRECT, then INVESTIGATIVE, then
    EXPLORATORY -- and simple heuristics decide when none of them match.

    Usage:
        classifier = IntentClassifier()
        result = classifier.classify("Why are customers churning?")
        # IntentResult(mode="investigative", confidence=0.85, ...)
        result = classifier.classify("Make a scatter plot of age vs income")
        # IntentResult(mode="direct", confidence=0.9, ...)
        result = classifier.classify("Analyze this dataset")
        # IntentResult(mode="exploratory", confidence=0.8, ...)
    """

    # Leading question words for the heuristic fallback. The trailing ``\b``
    # keeps "do" from matching "download", "how" from matching "however", etc.
    _QUESTION_START_RE = re.compile(
        r"^(?:why|how|what|which|is there|are there|does|do)\b"
    )

    # Substrings that signal an ML request when a target column is supplied.
    _ML_ACTION_WORDS = ("predict", "train", "model", "classify", "regression")

    def classify(
        self,
        query: str,
        dataset_info: Optional[Dict[str, Any]] = None,
        has_target_col: bool = False
    ) -> "IntentResult":
        """
        Classify user intent into an execution mode.

        Args:
            query: User's natural language query. ``None`` or an empty string
                is treated as "no query" and ends up in the exploratory
                default.
            dataset_info: Optional dataset schema info (accepted for
                interface compatibility; not consulted by the current rules).
            has_target_col: Whether the user provided a target column.

        Returns:
            IntentResult with mode, confidence, reasoning and sub_intent.
        """
        query_lower = (query or "").lower().strip()

        # Phase 1: DIRECT patterns (strongest evidence)
        direct_match = self._match_patterns(query_lower, DIRECT_PATTERNS)
        if direct_match:
            pattern, sub_intent = direct_match
            return IntentResult(
                mode="direct",
                confidence=0.90,
                reasoning=f"Direct command detected: {sub_intent} (pattern: {pattern[:50]})",
                sub_intent=sub_intent
            )

        # Phase 2: INVESTIGATIVE patterns
        invest_match = self._match_patterns(query_lower, INVESTIGATIVE_PATTERNS)
        if invest_match:
            _, sub_intent = invest_match
            return IntentResult(
                mode="investigative",
                confidence=0.85,
                reasoning=f"Analytical question detected: {sub_intent}",
                sub_intent=sub_intent
            )

        # Phase 3: EXPLORATORY patterns
        explore_match = self._match_patterns(query_lower, EXPLORATORY_PATTERNS)
        if explore_match:
            _, sub_intent = explore_match
            # Special case: "full pipeline" with target col -> direct ML pipeline
            if sub_intent == "full_pipeline" and has_target_col:
                return IntentResult(
                    mode="direct",
                    confidence=0.85,
                    reasoning="Full ML pipeline requested with target column",
                    sub_intent="full_ml_pipeline"
                )
            return IntentResult(
                mode="exploratory",
                confidence=0.80,
                reasoning=f"Open-ended analysis request: {sub_intent}",
                sub_intent=sub_intent
            )

        # Phase 4: Heuristic fallback
        return self._heuristic_classify(query_lower, has_target_col)

    def _match_patterns(self, query: str, patterns: list) -> Optional[Tuple[str, str]]:
        """Return the first (pattern, sub_intent) pair whose regex matches, else None."""
        for pattern, sub_intent in patterns:
            if re.search(pattern, query, re.IGNORECASE):
                return (pattern, sub_intent)
        return None

    def _heuristic_classify(self, query: str, has_target_col: bool) -> "IntentResult":
        """
        Fallback classification using simple heuristics.

        Args:
            query: Lower-cased, stripped query text.
            has_target_col: Whether the user provided a target column.

        Returns:
            IntentResult chosen by the first heuristic that applies; the
            exploratory default when none does.
        """
        # Question words -> investigative (whole-word match only)
        if self._QUESTION_START_RE.match(query):
            return IntentResult(
                mode="investigative",
                confidence=0.60,
                reasoning="Query starts with question word, likely analytical",
                sub_intent="general_question"
            )

        # Very short queries -> likely direct commands. An empty query has
        # nothing to execute, so it is excluded and reaches the default below.
        word_count = len(query.split())
        if 0 < word_count <= 5:
            return IntentResult(
                mode="direct",
                confidence=0.55,
                reasoning="Short query, likely a direct command",
                sub_intent="short_command"
            )

        # Has target column + action verbs -> direct ML pipeline
        if has_target_col and any(word in query for word in self._ML_ACTION_WORDS):
            return IntentResult(
                mode="direct",
                confidence=0.75,
                reasoning="Target column provided with ML action verb",
                sub_intent="ml_pipeline"
            )

        # Default: exploratory (safest default for data science)
        return IntentResult(
            mode="exploratory",
            confidence=0.40,
            reasoning="No strong pattern match, defaulting to exploratory analysis",
            sub_intent="default"
        )

    @staticmethod
    def is_follow_up(query: str) -> bool:
        """
        Detect whether the query looks like a follow-up to earlier analysis.

        Follow-ups should generally be INVESTIGATIVE (they ask about something
        specific in the context of previous results).

        Args:
            query: User's natural language query (``None`` counts as empty).

        Returns:
            True when any follow-up pattern matches the lower-cased query.
        """
        follow_up_patterns = [
            r"^(now|next|also|and|then)\b",
            r"\b(the same|that|this|those|these)\b.*\b(data|model|result|plot|chart)",
            r"\b(more|another|different)\b.*\b(plot|chart|analysis|model)",
            r"\b(what about|how about|can you also)\b",
            r"\b(using|with)\b.*\b(the same|that|this)\b",
        ]
        query_lower = (query or "").lower().strip()
        return any(re.search(p, query_lower) for p in follow_up_patterns)


# Module-level alias so the helper is reachable both as
# ``IntentClassifier.is_follow_up(...)`` and as a bare ``is_follow_up(...)``.
is_follow_up = IntentClassifier.is_follow_up