Spaces:
Running
Running
| """ | |
| Dataset definitions for the LexiMind multitask training pipeline. | |
| Defines PyTorch Dataset classes and data loading utilities for summarization, | |
| emotion classification, and topic classification tasks. Supports both JSON | |
| array and JSONL file formats. | |
| Author: Oliver Perrin | |
| Date: December 2025 | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Callable, Dict, Iterable, List, Sequence, Set, TypeVar | |
| from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer | |
| from torch.utils.data import Dataset | |
@dataclass
class SummarizationExample:
    """Container for abstractive summarization samples.

    Attributes:
        source: Full input document to be summarized.
        summary: Reference (target) summary text.
    """

    source: str
    summary: str
@dataclass
class EmotionExample:
    """Container for multi-label emotion classification samples.

    Attributes:
        text: Input text to classify.
        emotions: Zero or more emotion label strings attached to the text.
    """

    text: str
    emotions: Sequence[str]
@dataclass
class TopicExample:
    """Container for topic clustering / classification samples.

    Attributes:
        text: Input text to classify.
        topic: Single topic label string for the text.
    """

    text: str
    topic: str
class SummarizationDataset(Dataset[SummarizationExample]):
    """Dataset yielding encoder-decoder training pairs."""

    def __init__(self, examples: Iterable[SummarizationExample]) -> None:
        # Materialize once so repeated epochs and random access are cheap.
        self._items = list(examples)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> SummarizationExample:
        return self._items[index]
class EmotionDataset(Dataset[EmotionExample]):
    """Dataset that owns a scikit-learn MultiLabelBinarizer for emotions.

    If no binarizer is supplied, one is fitted on the emotion labels of the
    provided examples; a supplied binarizer must already be fitted.
    """

    def __init__(
        self,
        examples: Iterable[EmotionExample],
        *,
        binarizer: MultiLabelBinarizer | None = None,
    ) -> None:
        self._items = list(examples)
        if binarizer is not None:
            # Caller-supplied binarizer: require evidence that it was fitted.
            if not hasattr(binarizer, "classes_"):
                raise ValueError(
                    "Provided MultiLabelBinarizer must be pre-fitted with 'classes_' attribute."
                )
            self._binarizer = binarizer
        else:
            # Fit a fresh binarizer on this dataset's label sets.
            self._binarizer = MultiLabelBinarizer()
            self._binarizer.fit([item.emotions for item in self._items])

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> EmotionExample:
        return self._items[index]

    def binarizer(self) -> MultiLabelBinarizer:
        """Return the fitted label binarizer owned by this dataset."""
        return self._binarizer

    def emotion_classes(self) -> List[str]:
        """Return the emotion label vocabulary in binarizer order."""
        return list(self._binarizer.classes_)
class TopicDataset(Dataset[TopicExample]):
    """Dataset that owns a LabelEncoder for topic ids.

    If no encoder is supplied, one is fitted on the topics of the provided
    examples; a supplied encoder must already be fitted.
    """

    def __init__(
        self,
        examples: Iterable[TopicExample],
        *,
        encoder: LabelEncoder | None = None,
    ) -> None:
        self._items = list(examples)
        if encoder is not None:
            # Caller-supplied encoder: require evidence that it was fitted.
            if not hasattr(encoder, "classes_"):
                raise ValueError(
                    "Provided LabelEncoder must be pre-fitted with 'classes_' attribute."
                )
            self._encoder = encoder
        else:
            # Fit a fresh encoder over this dataset's topic labels.
            self._encoder = LabelEncoder().fit([item.topic for item in self._items])

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> TopicExample:
        return self._items[index]

    def encoder(self) -> LabelEncoder:
        """Return the fitted label encoder owned by this dataset."""
        return self._encoder

    def topic_classes(self) -> List[str]:
        """Return the topic label vocabulary in encoder order."""
        return list(self._encoder.classes_)
| T = TypeVar("T") | |
| def _safe_json_load(handle, path: Path) -> object: | |
| try: | |
| return json.load(handle) | |
| except json.JSONDecodeError as exc: | |
| raise ValueError(f"Failed to parse JSON in '{path}': {exc}") from exc | |
| def _safe_json_loads(data: str, path: Path, line_number: int) -> object: | |
| try: | |
| return json.loads(data) | |
| except json.JSONDecodeError as exc: | |
| raise ValueError(f"Failed to parse JSON in '{path}' at line {line_number}: {exc}") from exc | |
| def _validate_keys( | |
| payload: dict, | |
| required_keys: Sequence[str], | |
| position: int, | |
| *, | |
| path: Path, | |
| is_array: bool = False, | |
| ) -> None: | |
| missing = [key for key in required_keys if key not in payload] | |
| if missing: | |
| keys = ", ".join(sorted(missing)) | |
| location = "index" if is_array else "line" | |
| raise KeyError(f"Missing required keys ({keys}) at {location} {position} of '{path}'") | |
def _load_jsonl_generic(
    path: str,
    constructor: Callable[[dict], T],
    required_keys: Sequence[str],
) -> List[T]:
    """Load dataset records from either a JSON-array file or a JSONL file.

    The format is sniffed from the first non-whitespace character: '[' means
    the whole file is a single JSON array of objects; anything else is
    treated as JSONL (one JSON object per non-blank line).

    Args:
        path: Filesystem path to the dataset file.
        constructor: Callable that builds an example from a validated payload dict.
        required_keys: Keys every payload must contain.

    Returns:
        List of constructed examples, in file order.

    Raises:
        FileNotFoundError: If the path does not exist.
        ValueError: If the path is not a regular file, the file is empty,
            the JSON is malformed, or a payload is not a JSON object.
        KeyError: If a payload is missing any required key.
    """
    data_path = Path(path)
    if not data_path.exists():
        raise FileNotFoundError(f"Dataset file '{data_path}' does not exist")
    if not data_path.is_file():
        raise ValueError(f"Dataset path '{data_path}' is not a file")
    items: List[T] = []
    with data_path.open("r", encoding="utf-8") as handle:
        # Peek one character at a time until the first non-whitespace byte,
        # then seek back so parsing starts exactly at that character.
        first_non_ws = ""
        while True:
            pos = handle.tell()
            char = handle.read(1)
            if not char:
                # EOF reached without any non-whitespace content.
                break
            if not char.isspace():
                first_non_ws = char
                handle.seek(pos)
                break
        if not first_non_ws:
            raise ValueError(f"Dataset file '{data_path}' is empty or contains only whitespace")
        if first_non_ws == "[":
            # Whole-file JSON array branch.
            payloads = _safe_json_load(handle, data_path)
            if not isinstance(payloads, list):
                raise ValueError(
                    f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}"
                )
            for idx, payload in enumerate(payloads):
                if not isinstance(payload, dict):
                    raise ValueError(
                        f"Expected objects in array for '{data_path}', found {type(payload).__name__} at index {idx}"
                    )
                _validate_keys(payload, required_keys, idx, path=data_path, is_array=True)
                items.append(constructor(payload))
        else:
            # JSONL branch: rewind to the file start so reported line numbers
            # count every physical line, including leading blank ones.
            handle.seek(0)
            line_number = 0
            for line in handle:
                line_number += 1
                if not line.strip():
                    # Blank lines are tolerated and skipped.
                    continue
                payload = _safe_json_loads(line, data_path, line_number)
                if not isinstance(payload, dict):
                    raise ValueError(
                        f"Expected JSON object per line in '{data_path}', found {type(payload).__name__} at line {line_number}"
                    )
                _validate_keys(payload, required_keys, line_number, path=data_path)
                items.append(constructor(payload))
    return items
def load_summarization_jsonl(path: str) -> List[SummarizationExample]:
    """Load summarization (source, summary) pairs from a JSON/JSONL file."""

    def build(payload: dict) -> SummarizationExample:
        return SummarizationExample(source=payload["source"], summary=payload["summary"])

    return _load_jsonl_generic(path, build, required_keys=("source", "summary"))
def load_emotion_jsonl(path: str) -> List[EmotionExample]:
    """Load emotion examples from a JSON/JSONL file; 'emotions' defaults to []."""

    def build(payload: dict) -> EmotionExample:
        # Only 'text' is mandatory; an absent label list means "no emotions".
        return EmotionExample(text=payload["text"], emotions=payload.get("emotions", []))

    return _load_jsonl_generic(path, build, required_keys=("text",))
def load_topic_jsonl(path: str) -> List[TopicExample]:
    """Load topic (text, topic) examples from a JSON/JSONL file."""

    def build(payload: dict) -> TopicExample:
        return TopicExample(text=payload["text"], topic=payload["topic"])

    return _load_jsonl_generic(path, build, required_keys=("text", "topic"))
# --------------- Cross-Task Deduplication ---------------
| def _text_fingerprint(text: str, n_chars: int = 200) -> str: | |
| """Create a stable fingerprint from the first N characters of text. | |
| Uses a hash of the normalized (lowered, whitespace-collapsed) prefix | |
| to detect document-level overlap across tasks. | |
| """ | |
| normalized = " ".join(text.lower().split())[:n_chars] | |
| return hashlib.md5(normalized.encode("utf-8")).hexdigest() | |
def deduplicate_across_tasks(
    summ_examples: List[SummarizationExample],
    topic_examples: List[TopicExample],
    emotion_examples: List[EmotionExample] | None = None,
) -> Dict[str, int]:
    """Detect and report cross-task document overlap.

    Checks whether texts appearing in the summarization dataset also appear
    in the topic or emotion datasets, which could create data leakage in MTL.

    Returns:
        Dict with overlap counts between task pairs.
    """
    summ_fps: Set[str] = {_text_fingerprint(ex.source) for ex in summ_examples}
    topic_fps: Set[str] = {_text_fingerprint(ex.text) for ex in topic_examples}
    report: Dict[str, int] = {}
    report["summ_topic_overlap"] = len(summ_fps & topic_fps)
    report["summ_total"] = len(summ_fps)
    report["topic_total"] = len(topic_fps)
    # Emotion overlap is reported only when emotion examples are provided
    # (a missing or empty list skips these keys entirely).
    if emotion_examples:
        emotion_fps: Set[str] = {_text_fingerprint(ex.text) for ex in emotion_examples}
        report["summ_emotion_overlap"] = len(summ_fps & emotion_fps)
        report["topic_emotion_overlap"] = len(topic_fps & emotion_fps)
        report["emotion_total"] = len(emotion_fps)
    return report
def remove_overlapping_examples(
    primary_examples: List[TopicExample],
    reference_examples: List[SummarizationExample],
    split: str = "val",
) -> tuple[List[TopicExample], int]:
    """Remove topic examples whose texts overlap with summarization data.

    This prevents cross-task data leakage where a document seen during
    summarization training could boost topic classification on validation/test.

    Args:
        primary_examples: Topic examples to filter
        reference_examples: Summarization examples to check against
        split: Name of split being processed (for logging)

    Returns:
        Tuple of (filtered_examples, num_removed)
    """
    reference_fps = {_text_fingerprint(ex.source) for ex in reference_examples}
    filtered = [
        ex for ex in primary_examples if _text_fingerprint(ex.text) not in reference_fps
    ]
    removed = len(primary_examples) - len(filtered)
    if removed > 0:
        print(f" Dedup: removed {removed} overlapping examples from topic {split}")
    return filtered, removed