| """
|
| Dataset Validator - checks JSONL training dataset quality.
|
|
|
| Validates format, structure, duplicates, length, diversity,
|
| and can auto-filter to produce a clean dataset.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import hashlib
|
| import json
|
| import os
|
| import re
|
| import sys
|
| from collections import Counter, defaultdict
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
# Directory holding this script, and the project root one level above it.
_THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = _THIS_DIR.parent

# Put the project root at the front of the import path (only if it is not
# already there) so project-level modules resolve when run as a script.
if not any(entry == str(_PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_PROJECT_ROOT))
|
|
|
|
|
|
|
|
|
|
|
|
|
def _text_hash(text: str) -> str:
    """Return the hex SHA-256 digest of *text* after normalisation.

    Normalisation trims surrounding whitespace, lower-cases the text and
    collapses every internal whitespace run to a single space, so texts that
    differ only in case or spacing produce the same digest.
    """
    canonical = " ".join(text.strip().lower().split())
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
def _word_set(text: str) -> Set[str]:
    """Return the distinct lower-case ASCII words of two or more letters.

    The result is the token set compared with Jaccard similarity for
    near-duplicate detection.
    """
    lowered = text.lower()
    return {match.group(0) for match in re.finditer(r"[a-z]{2,}", lowered)}
|
|
|
|
|
def _jaccard_similarity(a: Set[str], b: Set[str]) -> float:
    """Return the Jaccard index len(a & b) / len(a | b) of two sets.

    Two empty sets are treated as identical and score 1.0; callers that do
    not want "both empty" to count as a match must skip empty sets themselves.
    """
    if not a and not b:
        return 1.0
    # At least one set is non-empty here, so the union is non-empty and the
    # division is safe (the former `if not union` branch was unreachable).
    return len(a & b) / len(a | b)
|
|
|
|
|
# Common English function words ignored when picking topic words. Built once
# at import time instead of on every call. Entries shorter than three letters
# can never match the ``[a-z]{3,}`` tokeniser below; they are harmless.
_TOPIC_STOP_WORDS = frozenset({
    "the", "a", "an", "is", "are", "was", "were", "be", "been",
    "have", "has", "had", "do", "does", "did", "will", "would",
    "to", "of", "in", "for", "on", "with", "at", "by", "from",
    "as", "and", "but", "or", "if", "that", "this", "what",
    "which", "it", "its", "they", "them", "their", "not", "you",
    "your", "can", "could", "should", "may", "might", "must",
    "how", "why", "when", "where", "who", "whom", "about",
})


def _extract_topic_words(text: str, top_n: int = 5) -> List[str]:
    """Return up to *top_n* dominant topic words of *text*.

    Words are lower-cased ASCII runs of three or more letters with stop words
    removed, ordered by frequency. Equal counts keep first-occurrence order
    (``Counter.most_common`` ordering).
    """
    words = re.findall(r"[a-z]{3,}", text.lower())
    counts = Counter(w for w in words if w not in _TOPIC_STOP_WORDS)
    return [w for w, _ in counts.most_common(top_n)]
|
|
|
|
|
|
|
|
|
|
|
|
|
class ValidationIssue:
    """One problem found while validating a dataset.

    Public attributes:
        line_num: line in the input file the issue refers to (0 is used for
            dataset-wide issues such as low topic diversity).
        severity: "error", "warning" or "info".
        code: short machine-readable identifier, e.g. "PARSE_ERROR".
        message: human-readable description.
    """

    def __init__(self, line_num: int, severity: str, code: str, message: str):
        self.line_num, self.severity = line_num, severity
        self.code, self.message = code, message

    def __repr__(self) -> str:
        label = self.severity.upper()
        return "[{}] Line {}: {} - {}".format(
            label, self.line_num, self.code, self.message
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetValidator:
    """Validate and clean JSONL training datasets.

    Every non-empty line is expected to be a JSON object whose ``messages``
    field is a list of ``{"role": ..., "content": ...}`` objects containing
    at least one message for each role in ``REQUIRED_ROLES``. Length,
    duplicate and topic checks look at the *last* user message and the
    *last* assistant message of each entry.
    """

    REQUIRED_ROLES = {"system", "user", "assistant"}

    def __init__(
        self,
        min_response_length: int = 50,
        max_response_length: int = 10000,
        near_duplicate_threshold: float = 0.85,
    ):
        """Store the validation thresholds.

        Args:
            min_response_length: assistant responses with fewer words than
                this are reported (and filtered) as too short.
            max_response_length: assistant responses with more words than
                this are reported (and filtered) as too long.
            near_duplicate_threshold: Jaccard similarity of user-message word
                sets at or above which an entry is reported as a near-duplicate.
        """
        self.min_response_length = min_response_length
        self.max_response_length = max_response_length
        self.near_duplicate_threshold = near_duplicate_threshold

    # -- helpers shared by validate() and filter_dataset() -----------------

    @staticmethod
    def _scan_messages(
        messages: List[Any],
    ) -> Tuple[Set[str], str, str, bool, int]:
        """Summarise a ``messages`` list without trusting its element types.

        Returns ``(roles, user_text, assistant_text, has_empty, malformed)``:

        * ``roles`` - string roles present (a missing/null role counts as "");
        * ``user_text`` / ``assistant_text`` - content of the last message
          with that role, or "" if there is none;
        * ``has_empty`` - some well-formed message has null, empty or
          whitespace-only content;
        * ``malformed`` - number of messages that are not JSON objects or whose
          role/content is not a string (e.g. a list of content parts).
        """
        roles: Set[str] = set()
        user_text = ""
        assistant_text = ""
        has_empty = False
        malformed = 0

        for msg in messages:
            if not isinstance(msg, dict):
                malformed += 1
                continue
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role is None:
                role = ""
            if content is None:
                content = ""  # JSON null is treated as empty content
            if not isinstance(role, str) or not isinstance(content, str):
                malformed += 1
                continue

            roles.add(role)
            if role == "assistant":
                assistant_text = content
            elif role == "user":
                user_text = content
            if not content.strip():
                has_empty = True

        return roles, user_text, assistant_text, has_empty, malformed

    @staticmethod
    def _dedup_key(user_text: str, assistant_text: str) -> str:
        """Return the exact-duplicate key for a (user, assistant) pair.

        Each side is normalised and hashed separately. Hashing the joined text
        would let words shift across the boundary, so ("a b", "c") and
        ("a", "b c") would wrongly count as duplicates.
        """
        return f"{_text_hash(user_text)}:{_text_hash(assistant_text)}"

    # -- validation ---------------------------------------------------------

    def validate(self, filepath: str) -> Dict[str, Any]:
        """Validate a JSONL dataset file.

        Args:
            filepath: path to the JSONL file (read as UTF-8).

        Returns:
            A report dict with per-category counts (``total_lines``,
            ``valid``, ``invalid``, ``parse_errors``, ``missing_roles``,
            ``exact_duplicates``, ``near_duplicates``, ``too_short``,
            ``too_long``, ``empty_content``), topic-diversity figures
            (``unique_topics``, ``topic_concentration``, ``top_topics``),
            ``response_length_stats``, the ``ValidationIssue`` list under
            ``issues``, a per-line ``line_validity`` list and the parsed
            ``valid_entries``.

        Raises:
            FileNotFoundError: if *filepath* does not exist.
        """
        path = Path(filepath)
        if not path.exists():
            raise FileNotFoundError(f"Dataset file not found: {path}")

        issues: List[ValidationIssue] = []
        valid_entries: List[Dict[str, Any]] = []
        line_validity: List[bool] = []

        # Duplicate key -> line number of its first occurrence.
        exact_keys: Dict[str, int] = {}
        # (line number, user word set) of earlier entries. NOTE: comparing
        # against every earlier entry makes near-duplicate detection O(n^2).
        near_dup_sets: List[Tuple[int, Set[str]]] = []

        stats: Dict[str, Any] = {
            "total_lines": 0,
            "valid": 0,
            "invalid": 0,
            "parse_errors": 0,
            "missing_roles": 0,
            "exact_duplicates": 0,
            "near_duplicates": 0,
            "too_short": 0,
            "too_long": 0,
            "empty_content": 0,
            "response_lengths": [],
            "topic_words": [],
        }

        with open(path, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, start=1):
                stats["total_lines"] += 1
                line = raw_line.strip()

                if not line:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_LINE", "Empty line"
                    ))
                    line_validity.append(False)
                    stats["invalid"] += 1
                    continue

                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    issues.append(ValidationIssue(
                        line_num, "error", "PARSE_ERROR",
                        f"Invalid JSON: {e}"
                    ))
                    line_validity.append(False)
                    stats["parse_errors"] += 1
                    stats["invalid"] += 1
                    continue

                # Valid JSON need not be an object (e.g. a bare list/string).
                messages = entry.get("messages") if isinstance(entry, dict) else None
                if not isinstance(messages, list):
                    issues.append(ValidationIssue(
                        line_num, "error", "NO_MESSAGES",
                        "Missing or invalid 'messages' field"
                    ))
                    stats["invalid"] += 1
                    line_validity.append(False)
                    continue

                (roles_present, user_content, assistant_content,
                 has_empty, malformed) = self._scan_messages(messages)
                entry_valid = True

                if malformed:
                    issues.append(ValidationIssue(
                        line_num, "error", "INVALID_MESSAGE",
                        f"{malformed} message(s) are not objects or have a "
                        f"non-string role/content"
                    ))
                    entry_valid = False

                missing_roles = self.REQUIRED_ROLES - roles_present
                if missing_roles:
                    # Sorted so the message is stable across runs.
                    issues.append(ValidationIssue(
                        line_num, "error", "MISSING_ROLES",
                        f"Missing roles: {', '.join(sorted(missing_roles))}"
                    ))
                    entry_valid = False
                    stats["missing_roles"] += 1

                if has_empty:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_CONTENT",
                        "One or more messages have empty content"
                    ))
                    stats["empty_content"] += 1

                resp_len = len(assistant_content.split())
                stats["response_lengths"].append(resp_len)

                if resp_len < self.min_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_SHORT",
                        f"Assistant response too short: {resp_len} words "
                        f"(min: {self.min_response_length})"
                    ))
                    stats["too_short"] += 1

                if resp_len > self.max_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_LONG",
                        f"Assistant response too long: {resp_len} words "
                        f"(max: {self.max_response_length})"
                    ))
                    stats["too_long"] += 1

                key = self._dedup_key(user_content, assistant_content)
                first_line = exact_keys.get(key)
                if first_line is not None:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EXACT_DUPLICATE",
                        f"Exact duplicate of line {first_line}"
                    ))
                    stats["exact_duplicates"] += 1
                    entry_valid = False
                else:
                    exact_keys[key] = line_num

                # Empty word sets (no ASCII words in the prompt) are skipped:
                # Jaccard(empty, empty) == 1.0 would flag them all as
                # near-duplicates of each other.
                user_words = _word_set(user_content)
                if user_words:
                    for prev_line, prev_words in near_dup_sets:
                        sim = _jaccard_similarity(user_words, prev_words)
                        if sim >= self.near_duplicate_threshold:
                            issues.append(ValidationIssue(
                                line_num, "info", "NEAR_DUPLICATE",
                                f"Near-duplicate of line {prev_line} "
                                f"(Jaccard: {sim:.3f})"
                            ))
                            stats["near_duplicates"] += 1
                            break
                    near_dup_sets.append((line_num, user_words))

                topic_words = _extract_topic_words(user_content + " " + assistant_content)
                stats["topic_words"].extend(topic_words)

                if entry_valid:
                    stats["valid"] += 1
                    valid_entries.append(entry)
                    line_validity.append(True)
                else:
                    stats["invalid"] += 1
                    line_validity.append(False)

        # Topic diversity: share of topic-word mentions taken by the top 3.
        topic_counts = Counter(stats["topic_words"])
        total_topics = len(topic_counts)
        top_topics = topic_counts.most_common(20)

        if topic_counts:
            top3_count = sum(c for _, c in topic_counts.most_common(3))
            total_count = sum(topic_counts.values())
            concentration = top3_count / total_count if total_count else 0
        else:
            concentration = 0

        if concentration > 0.5:
            top_kw = ", ".join(w for w, _ in topic_counts.most_common(3))
            issues.append(ValidationIssue(
                0, "warning", "LOW_DIVERSITY",
                f"Dataset is concentrated on few topics ({concentration:.0%} "
                f"in top-3: {top_kw}). Consider adding more diverse examples."
            ))

        # Response length distribution (in words); percentiles use the
        # nearest-rank element of the sorted list.
        lengths = stats["response_lengths"]
        length_stats: Dict[str, Any] = {}
        if lengths:
            lengths_sorted = sorted(lengths)
            length_stats = {
                "min": lengths_sorted[0],
                "max": lengths_sorted[-1],
                "mean": round(sum(lengths) / len(lengths), 1),
                "median": lengths_sorted[len(lengths) // 2],
                "p10": lengths_sorted[int(len(lengths) * 0.1)],
                "p90": lengths_sorted[int(len(lengths) * 0.9)],
            }

        return {
            "filepath": str(path),
            "total_lines": stats["total_lines"],
            "valid": stats["valid"],
            "invalid": stats["invalid"],
            "parse_errors": stats["parse_errors"],
            "missing_roles": stats["missing_roles"],
            "exact_duplicates": stats["exact_duplicates"],
            "near_duplicates": stats["near_duplicates"],
            "too_short": stats["too_short"],
            "too_long": stats["too_long"],
            "empty_content": stats["empty_content"],
            "unique_topics": total_topics,
            "topic_concentration": round(concentration, 4),
            "top_topics": top_topics,
            "response_length_stats": length_stats,
            "issues": issues,
            "line_validity": line_validity,
            "valid_entries": valid_entries,
        }

    # -- filtering ----------------------------------------------------------

    def filter_dataset(
        self,
        filepath: str,
        output_path: str,
        remove_duplicates: bool = True,
        remove_short: bool = True,
        remove_long: bool = True,
        remove_invalid: bool = True,
    ) -> Dict[str, Any]:
        """Write a cleaned copy of a JSONL dataset.

        Empty lines are always dropped. The other checks mirror validate():

        * ``remove_invalid``: drop unparseable lines, entries without a
          ``messages`` list, entries with malformed messages and entries
          missing a required role. When False these are kept. Lines that are
          not valid JSON are copied verbatim, because they cannot be
          re-serialised.
        * ``remove_short`` / ``remove_long``: drop entries whose last
          assistant message is outside the configured word-count bounds.
        * ``remove_duplicates``: keep only the first entry among exact
          (user, assistant) duplicates that reach this check.

        Args:
            filepath: input JSONL path.
            output_path: destination path; parent directories are created.

        Returns:
            ``{"input_lines", "kept", "removed", "removal_reasons"}``, where
            ``removal_reasons`` maps a reason name to its count.

        Raises:
            FileNotFoundError: if *filepath* does not exist.
            ValueError: if *output_path* is the input file. Opening it for
                writing would truncate the data before it is read.
        """
        src = Path(filepath)
        if not src.exists():
            raise FileNotFoundError(f"Dataset file not found: {src}")
        if os.path.exists(output_path) and os.path.samefile(src, output_path):
            raise ValueError(
                f"Output path must differ from the input dataset: {output_path}"
            )

        kept = 0
        removed = 0
        input_lines = 0
        reasons: Dict[str, int] = defaultdict(int)
        seen_keys: Set[str] = set()

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

        with open(src, "r", encoding="utf-8") as fin, \
                open(output_path, "w", encoding="utf-8") as fout:
            for raw_line in fin:
                input_lines += 1
                line = raw_line.strip()
                if not line:
                    removed += 1
                    reasons["empty_line"] += 1
                    continue

                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    if remove_invalid:
                        removed += 1
                        reasons["parse_error"] += 1
                    else:
                        fout.write(line + "\n")
                        kept += 1
                    continue

                messages = entry.get("messages") if isinstance(entry, dict) else None
                if not isinstance(messages, list):
                    if remove_invalid:
                        removed += 1
                        reasons["no_messages"] += 1
                    else:
                        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        kept += 1
                    continue

                roles, user_text, assistant_text, _, malformed = \
                    self._scan_messages(messages)
                if remove_invalid and malformed:
                    removed += 1
                    reasons["invalid_message"] += 1
                    continue
                if remove_invalid and self.REQUIRED_ROLES - roles:
                    removed += 1
                    reasons["missing_roles"] += 1
                    continue

                word_count = len(assistant_text.split())
                if remove_short and word_count < self.min_response_length:
                    removed += 1
                    reasons["too_short"] += 1
                    continue
                if remove_long and word_count > self.max_response_length:
                    removed += 1
                    reasons["too_long"] += 1
                    continue

                if remove_duplicates:
                    key = self._dedup_key(user_text, assistant_text)
                    if key in seen_keys:
                        removed += 1
                        reasons["duplicate"] += 1
                        continue
                    seen_keys.add(key)

                fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                kept += 1

        return {
            "input_lines": input_lines,
            "kept": kept,
            "removed": removed,
            "removal_reasons": dict(reasons),
        }

    # -- reporting ----------------------------------------------------------

    def format_report(self, report: Dict[str, Any]) -> str:
        """Format a validate() report as readable multi-line text.

        Up to 20 errors and 20 warnings are listed; "info" issues are only
        reflected in the summary counts. The verdict is PASS when nothing was
        flagged, FAIL when more than 10% of lines are invalid, and WARN
        otherwise.
        """
        lines: List[str] = []
        lines.append("=" * 70)
        lines.append(" DATASET VALIDATION REPORT")
        lines.append("=" * 70)
        lines.append(f" File: {report['filepath']}")
        lines.append("")

        lines.append("-" * 70)
        lines.append(" SUMMARY")
        lines.append("-" * 70)
        lines.append(f" Total lines: {report['total_lines']}")
        lines.append(f" Valid: {report['valid']}")
        lines.append(f" Invalid: {report['invalid']}")
        lines.append(f" Parse errors: {report['parse_errors']}")
        lines.append(f" Missing roles: {report['missing_roles']}")
        lines.append(f" Exact duplicates: {report['exact_duplicates']}")
        lines.append(f" Near duplicates: {report['near_duplicates']}")
        lines.append(f" Too short: {report['too_short']}")
        lines.append(f" Too long: {report['too_long']}")
        lines.append(f" Empty content: {report['empty_content']}")

        ls = report.get("response_length_stats", {})
        if ls:
            lines.append("")
            lines.append("-" * 70)
            lines.append(" RESPONSE LENGTH (words)")
            lines.append("-" * 70)
            lines.append(f" Min: {ls.get('min', 'N/A')}")
            lines.append(f" Max: {ls.get('max', 'N/A')}")
            lines.append(f" Mean: {ls.get('mean', 'N/A')}")
            lines.append(f" Median: {ls.get('median', 'N/A')}")
            lines.append(f" P10: {ls.get('p10', 'N/A')}")
            lines.append(f" P90: {ls.get('p90', 'N/A')}")

        lines.append("")
        lines.append("-" * 70)
        lines.append(" TOPIC DIVERSITY")
        lines.append("-" * 70)
        lines.append(f" Unique topic words: {report.get('unique_topics', 0)}")
        lines.append(f" Top-3 concentration: {report.get('topic_concentration', 0):.1%}")
        top_topics = report.get("top_topics", [])
        if top_topics:
            lines.append(" Top topics:")
            for word, count in top_topics[:10]:
                lines.append(f" {word:<20s} {count}")

        issues = report.get("issues", [])
        error_issues = [i for i in issues if i.severity == "error"]
        warning_issues = [i for i in issues if i.severity == "warning"]

        if error_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" ERRORS ({len(error_issues)})")
            lines.append("-" * 70)
            for issue in error_issues[:20]:
                lines.append(f" {issue}")
            if len(error_issues) > 20:
                lines.append(f" ... and {len(error_issues) - 20} more errors")

        if warning_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" WARNINGS ({len(warning_issues)})")
            lines.append("-" * 70)
            for issue in warning_issues[:20]:
                lines.append(f" {issue}")
            if len(warning_issues) > 20:
                lines.append(f" ... and {len(warning_issues) - 20} more warnings")

        lines.append("")
        lines.append("-" * 70)
        # too_long is checked alongside too_short: a dataset with over-long
        # responses must not be reported as clean.
        if (report["invalid"] == 0
                and report["exact_duplicates"] == 0
                and report.get("near_duplicates", 0) == 0
                and report.get("too_short", 0) == 0
                and report.get("too_long", 0) == 0
                and report.get("empty_content", 0) == 0):
            lines.append(" VERDICT: PASS - Dataset is clean")
        elif report["invalid"] > report["total_lines"] * 0.1:
            lines.append(" VERDICT: FAIL - Too many invalid entries (>10%)")
        else:
            lines.append(" VERDICT: WARN - Some issues found, consider filtering")
        lines.append("-" * 70)

        lines.append("=" * 70)
        return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Validates a dataset and prints the report. Optionally saves a JSON
    version of the report and/or writes a filtered copy of the dataset. If
    the input cannot be read, or the filter step is rejected (e.g. output
    path equal to the input), it prints a message to stderr and exits with
    status 1 instead of dumping a traceback.
    """
    parser = argparse.ArgumentParser(
        description="Codette Dataset Validator - check and clean JSONL training data"
    )
    parser.add_argument(
        "dataset",
        help="Path to JSONL dataset file",
    )
    parser.add_argument(
        "--filter", "-f",
        metavar="OUTPUT",
        default=None,
        help="Auto-filter and write clean dataset to OUTPUT path",
    )
    parser.add_argument(
        "--min-length",
        type=int,
        default=50,
        help="Minimum assistant response length in words (default: 50)",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=10000,
        help="Maximum assistant response length in words (default: 10000)",
    )
    parser.add_argument(
        "--duplicate-threshold",
        type=float,
        default=0.85,
        help="Jaccard similarity threshold for near-duplicates (default: 0.85)",
    )
    parser.add_argument(
        "--json-report",
        metavar="PATH",
        default=None,
        help="Save report as JSON to this path",
    )

    args = parser.parse_args()

    validator = DatasetValidator(
        min_response_length=args.min_length,
        max_response_length=args.max_length,
        near_duplicate_threshold=args.duplicate_threshold,
    )

    print(f"Validating: {args.dataset}\n")
    try:
        report = validator.validate(args.dataset)
    except (OSError, UnicodeDecodeError) as exc:
        # Missing/unreadable file or non-UTF-8 content: user error, not a bug.
        print(f"Error: {exc}", file=sys.stderr)
        sys.exit(1)
    print(validator.format_report(report))

    if args.json_report:
        # Drop the bulky per-line fields; keep an issue count plus the first
        # 50 issues rendered via their repr().
        save_report = {k: v for k, v in report.items()
                       if k not in ("issues", "line_validity", "valid_entries")}
        save_report["issue_count"] = len(report["issues"])
        save_report["issues_summary"] = [repr(i) for i in report["issues"][:50]]
        os.makedirs(os.path.dirname(args.json_report) or ".", exist_ok=True)
        with open(args.json_report, "w", encoding="utf-8") as f:
            json.dump(save_report, f, indent=2, default=str)
        print(f"\nJSON report saved to: {args.json_report}")

    if args.filter:
        print(f"\nFiltering dataset -> {args.filter}")
        try:
            filter_stats = validator.filter_dataset(args.dataset, args.filter)
        except (OSError, ValueError) as exc:
            print(f"Error: {exc}", file=sys.stderr)
            sys.exit(1)
        print(f" Input lines: {filter_stats['input_lines']}")
        print(f" Kept: {filter_stats['kept']}")
        print(f" Removed: {filter_stats['removed']}")
        for reason, count in filter_stats["removal_reasons"].items():
            print(f" - {reason}: {count}")


if __name__ == "__main__":
    main()
|
|
|