# Codette-Reasoning / evaluation / dataset_validator.py
"""
Dataset Validator - checks JSONL training dataset quality.
Validates format, structure, duplicates, length, diversity,
and can auto-filter to produce a clean dataset.
"""
from __future__ import annotations
import argparse
import hashlib
import json
import os
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
# Make the project root importable when this file is executed directly
# (e.g. `python evaluation/dataset_validator.py`) rather than as a package.
_THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = _THIS_DIR.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _text_hash(text: str) -> str:
"""SHA-256 of normalised text for exact duplicate detection."""
normalised = re.sub(r"\s+", " ", text.strip().lower())
return hashlib.sha256(normalised.encode("utf-8")).hexdigest()
def _word_set(text: str) -> Set[str]:
"""Set of lowercase words for Jaccard similarity."""
return set(re.findall(r"[a-z]{2,}", text.lower()))
def _jaccard_similarity(a: Set[str], b: Set[str]) -> float:
if not a and not b:
return 1.0
union = a | b
if not union:
return 0.0
return len(a & b) / len(union)
def _extract_topic_words(text: str, top_n: int = 5) -> List[str]:
"""Extract dominant topic words from text."""
stop = {
"the", "a", "an", "is", "are", "was", "were", "be", "been",
"have", "has", "had", "do", "does", "did", "will", "would",
"to", "of", "in", "for", "on", "with", "at", "by", "from",
"as", "and", "but", "or", "if", "that", "this", "what",
"which", "it", "its", "they", "them", "their", "not", "you",
"your", "can", "could", "should", "may", "might", "must",
"how", "why", "when", "where", "who", "whom", "about",
}
words = re.findall(r"[a-z]{3,}", text.lower())
filtered = [w for w in words if w not in stop]
counts = Counter(filtered)
return [w for w, _ in counts.most_common(top_n)]
# ---------------------------------------------------------------------------
# Validation Issue
# ---------------------------------------------------------------------------
class ValidationIssue:
    """A single validation problem found in a dataset.

    Attributes are plain data:
      - line_num: 1-based dataset line number (0 for file-level issues)
      - severity: one of "error", "warning", "info"
      - code: short machine-readable code (e.g. "PARSE_ERROR")
      - message: human-readable description
    """

    def __init__(self, line_num: int, severity: str, code: str, message: str):
        self.line_num = line_num
        self.severity = severity
        self.code = code
        self.message = message

    def __repr__(self) -> str:
        return (
            f"[{self.severity.upper()}] "
            f"Line {self.line_num}: {self.code} - {self.message}"
        )
# ---------------------------------------------------------------------------
# DatasetValidator
# ---------------------------------------------------------------------------
class DatasetValidator:
    """Validate and clean JSONL chat-format training datasets.

    Each dataset line is expected to be a JSON object with a ``messages``
    list of ``{"role": ..., "content": ...}`` dicts covering the system,
    user and assistant roles.  ``validate`` produces a report of structural
    issues, duplicates, length outliers and topic diversity;
    ``filter_dataset`` writes a cleaned copy of the file.
    """

    # Every training example must contain all of these roles.
    REQUIRED_ROLES = {"system", "user", "assistant"}

    def __init__(
        self,
        min_response_length: int = 50,
        max_response_length: int = 10000,
        near_duplicate_threshold: float = 0.85,
    ):
        # Assistant responses outside [min, max] word counts are flagged.
        self.min_response_length = min_response_length
        self.max_response_length = max_response_length
        # Jaccard similarity of user prompts at/above which two entries
        # are reported as near-duplicates.
        self.near_duplicate_threshold = near_duplicate_threshold

    def validate(self, filepath: str) -> Dict[str, Any]:
        """Validate a JSONL dataset file.

        Returns a comprehensive report dict with:
        - statistics (total, valid, invalid, duplicate, etc.)
        - issues list (``ValidationIssue`` objects)
        - per-line validity flags and the parsed valid entries

        Raises:
            FileNotFoundError: if *filepath* does not exist.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Dataset file not found: {filepath}")
        issues: List[ValidationIssue] = []
        valid_entries: List[Dict[str, Any]] = []
        line_validity: List[bool] = []
        # Duplicate tracking
        exact_hashes: Dict[str, int] = {}  # content hash -> first line seen
        near_dup_sets: List[Tuple[int, Set[str]]] = []  # (line, user word set)
        # Stats
        stats = {
            "total_lines": 0,
            "valid": 0,
            "invalid": 0,
            "parse_errors": 0,
            "missing_roles": 0,
            "exact_duplicates": 0,
            "near_duplicates": 0,
            "too_short": 0,
            "too_long": 0,
            "empty_content": 0,
            "response_lengths": [],
            "topic_words": [],
        }
        with open(filepath, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, start=1):
                stats["total_lines"] += 1
                raw_line = raw_line.strip()
                if not raw_line:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_LINE", "Empty line"
                    ))
                    line_validity.append(False)
                    stats["invalid"] += 1
                    continue
                # Parse JSON
                try:
                    entry = json.loads(raw_line)
                except json.JSONDecodeError as e:
                    issues.append(ValidationIssue(
                        line_num, "error", "PARSE_ERROR",
                        f"Invalid JSON: {e}"
                    ))
                    line_validity.append(False)
                    stats["parse_errors"] += 1
                    stats["invalid"] += 1
                    continue
                entry_valid = True
                # Check messages structure
                messages = entry.get("messages")
                if not isinstance(messages, list):
                    issues.append(ValidationIssue(
                        line_num, "error", "NO_MESSAGES",
                        "Missing or invalid 'messages' field"
                    ))
                    line_validity.append(False)
                    stats["invalid"] += 1
                    continue
                # Check roles and collect the user/assistant texts.
                roles_present: Set[str] = set()
                assistant_content = ""
                user_content = ""
                has_empty = False
                for msg in messages:
                    # FIX: non-dict message entries previously raised
                    # AttributeError on msg.get(); treat them as empty
                    # content instead of crashing the whole run.
                    if not isinstance(msg, dict):
                        has_empty = True
                        continue
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    # FIX: guard non-string content (null, numbers, lists)
                    # which previously crashed content.strip()/.split().
                    if not isinstance(content, str):
                        content = ""
                    roles_present.add(role)
                    if role == "assistant":
                        assistant_content = content
                    elif role == "user":
                        user_content = content
                    if not content.strip():
                        has_empty = True
                missing_roles = self.REQUIRED_ROLES - roles_present
                if missing_roles:
                    issues.append(ValidationIssue(
                        line_num, "error", "MISSING_ROLES",
                        f"Missing roles: {missing_roles}"
                    ))
                    entry_valid = False
                    stats["missing_roles"] += 1
                if has_empty:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_CONTENT",
                        "One or more messages have empty content"
                    ))
                    stats["empty_content"] += 1
                # Response length (whitespace-delimited word count).
                resp_len = len(assistant_content.split())
                stats["response_lengths"].append(resp_len)
                if resp_len < self.min_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_SHORT",
                        f"Assistant response too short: {resp_len} words "
                        f"(min: {self.min_response_length})"
                    ))
                    stats["too_short"] += 1
                if resp_len > self.max_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_LONG",
                        f"Assistant response too long: {resp_len} words "
                        f"(max: {self.max_response_length})"
                    ))
                    stats["too_long"] += 1
                # Exact duplicate check (on combined user+assistant text).
                combined_text = user_content + " " + assistant_content
                h = _text_hash(combined_text)
                if h in exact_hashes:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EXACT_DUPLICATE",
                        f"Exact duplicate of line {exact_hashes[h]}"
                    ))
                    stats["exact_duplicates"] += 1
                    entry_valid = False
                else:
                    exact_hashes[h] = line_num
                # Near-duplicate check (Jaccard on user prompt).
                # NOTE: O(n^2) pairwise scan; acceptable for typical
                # fine-tune dataset sizes.
                if user_content:
                    user_words = _word_set(user_content)
                    for prev_line, prev_words in near_dup_sets:
                        sim = _jaccard_similarity(user_words, prev_words)
                        if sim >= self.near_duplicate_threshold:
                            issues.append(ValidationIssue(
                                line_num, "info", "NEAR_DUPLICATE",
                                f"Near-duplicate of line {prev_line} "
                                f"(Jaccard: {sim:.3f})"
                            ))
                            stats["near_duplicates"] += 1
                            break
                    near_dup_sets.append((line_num, user_words))
                # Topic extraction for diversity statistics.
                topic_words = _extract_topic_words(user_content + " " + assistant_content)
                stats["topic_words"].extend(topic_words)
                if entry_valid:
                    stats["valid"] += 1
                    valid_entries.append(entry)
                    line_validity.append(True)
                else:
                    stats["invalid"] += 1
                    line_validity.append(False)
        # Concept diversity
        topic_counts = Counter(stats["topic_words"])
        total_topics = len(set(stats["topic_words"]))
        top_topics = topic_counts.most_common(20)
        # Concentration ratio: if top-3 topics dominate, diversity is low.
        if topic_counts:
            top3_count = sum(c for _, c in topic_counts.most_common(3))
            total_count = sum(topic_counts.values())
            concentration = top3_count / total_count if total_count else 0
        else:
            concentration = 0
        if concentration > 0.5:
            top_kw = ", ".join(w for w, _ in topic_counts.most_common(3))
            issues.append(ValidationIssue(
                0, "warning", "LOW_DIVERSITY",
                f"Dataset is concentrated on few topics ({concentration:.0%} "
                f"in top-3: {top_kw}). Consider adding more diverse examples."
            ))
        # Build response length stats (simple index-based percentiles).
        lengths = stats["response_lengths"]
        length_stats: Dict[str, Any] = {}
        if lengths:
            lengths_sorted = sorted(lengths)
            length_stats = {
                "min": lengths_sorted[0],
                "max": lengths_sorted[-1],
                "mean": round(sum(lengths) / len(lengths), 1),
                "median": lengths_sorted[len(lengths) // 2],
                "p10": lengths_sorted[int(len(lengths) * 0.1)],
                "p90": lengths_sorted[int(len(lengths) * 0.9)],
            }
        report = {
            "filepath": str(filepath),
            "total_lines": stats["total_lines"],
            "valid": stats["valid"],
            "invalid": stats["invalid"],
            "parse_errors": stats["parse_errors"],
            "missing_roles": stats["missing_roles"],
            "exact_duplicates": stats["exact_duplicates"],
            "near_duplicates": stats["near_duplicates"],
            "too_short": stats["too_short"],
            "too_long": stats["too_long"],
            "empty_content": stats["empty_content"],
            "unique_topics": total_topics,
            "topic_concentration": round(concentration, 4),
            "top_topics": top_topics,
            "response_length_stats": length_stats,
            "issues": issues,
            "line_validity": line_validity,
            "valid_entries": valid_entries,
        }
        return report

    # -- auto-filter -------------------------------------------------------
    def filter_dataset(
        self,
        filepath: str,
        output_path: str,
        remove_duplicates: bool = True,
        remove_short: bool = True,
        remove_long: bool = True,
        remove_invalid: bool = True,
    ) -> Dict[str, int]:
        """Validate and write a cleaned dataset.

        Each remove_* flag controls one class of removal; with a flag off,
        matching entries are written through unchanged.
        Returns stats about the filtering: input_lines, kept, removed,
        and a removal_reasons breakdown.
        """
        # Run a full validation pass first (also verifies the file exists).
        report = self.validate(filepath)
        kept = 0
        removed = 0
        reasons: Dict[str, int] = defaultdict(int)
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        with open(filepath, "r", encoding="utf-8") as fin, \
                open(output_path, "w", encoding="utf-8") as fout:
            seen_hashes: Set[str] = set()
            for raw_line in fin:
                raw_line = raw_line.strip()
                if not raw_line:
                    # Blank lines are always dropped.
                    removed += 1
                    reasons["empty_line"] += 1
                    continue
                try:
                    entry = json.loads(raw_line)
                except json.JSONDecodeError:
                    if remove_invalid:
                        removed += 1
                        reasons["parse_error"] += 1
                    else:
                        # FIX: unparseable lines used to be silently lost
                        # even with remove_invalid=False; keep them verbatim.
                        fout.write(raw_line + "\n")
                        kept += 1
                    continue
                messages = entry.get("messages", [])
                if not isinstance(messages, list):
                    if remove_invalid:
                        removed += 1
                        reasons["no_messages"] += 1
                    else:
                        # FIX: previously fell through and crashed while
                        # iterating a non-list; write the entry through.
                        fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                        kept += 1
                    continue
                # Non-dict members contribute no role (guarded lookup).
                roles = {m.get("role") for m in messages if isinstance(m, dict)}
                if self.REQUIRED_ROLES - roles:
                    if remove_invalid:
                        removed += 1
                        reasons["missing_roles"] += 1
                        continue
                    # remove_invalid=False: keep processing the entry below.
                # Extract texts
                assistant_text = ""
                user_text = ""
                for m in messages:
                    if not isinstance(m, dict):
                        continue
                    content = m.get("content", "")
                    # FIX: guard non-string content before .split() below.
                    if not isinstance(content, str):
                        content = ""
                    if m.get("role") == "assistant":
                        assistant_text = content
                    elif m.get("role") == "user":
                        user_text = content
                # Length checks
                word_count = len(assistant_text.split())
                if remove_short and word_count < self.min_response_length:
                    removed += 1
                    reasons["too_short"] += 1
                    continue
                if remove_long and word_count > self.max_response_length:
                    removed += 1
                    reasons["too_long"] += 1
                    continue
                # Duplicate check
                if remove_duplicates:
                    h = _text_hash(user_text + " " + assistant_text)
                    if h in seen_hashes:
                        removed += 1
                        reasons["duplicate"] += 1
                        continue
                    seen_hashes.add(h)
                fout.write(json.dumps(entry, ensure_ascii=False) + "\n")
                kept += 1
        return {
            "input_lines": report["total_lines"],
            "kept": kept,
            "removed": removed,
            "removal_reasons": dict(reasons),
        }

    # -- report formatting -------------------------------------------------
    def format_report(self, report: Dict[str, Any]) -> str:
        """Format a ``validate`` report dict as readable multi-section text."""
        lines: List[str] = []
        lines.append("=" * 70)
        lines.append(" DATASET VALIDATION REPORT")
        lines.append("=" * 70)
        lines.append(f" File: {report['filepath']}")
        lines.append("")
        # Summary
        lines.append("-" * 70)
        lines.append(" SUMMARY")
        lines.append("-" * 70)
        lines.append(f" Total lines: {report['total_lines']}")
        lines.append(f" Valid: {report['valid']}")
        lines.append(f" Invalid: {report['invalid']}")
        lines.append(f" Parse errors: {report['parse_errors']}")
        lines.append(f" Missing roles: {report['missing_roles']}")
        lines.append(f" Exact duplicates: {report['exact_duplicates']}")
        lines.append(f" Near duplicates: {report['near_duplicates']}")
        lines.append(f" Too short: {report['too_short']}")
        lines.append(f" Too long: {report['too_long']}")
        lines.append(f" Empty content: {report['empty_content']}")
        # Length stats
        ls = report.get("response_length_stats", {})
        if ls:
            lines.append("")
            lines.append("-" * 70)
            lines.append(" RESPONSE LENGTH (words)")
            lines.append("-" * 70)
            lines.append(f" Min: {ls.get('min', 'N/A')}")
            lines.append(f" Max: {ls.get('max', 'N/A')}")
            lines.append(f" Mean: {ls.get('mean', 'N/A')}")
            lines.append(f" Median: {ls.get('median', 'N/A')}")
            lines.append(f" P10: {ls.get('p10', 'N/A')}")
            lines.append(f" P90: {ls.get('p90', 'N/A')}")
        # Diversity
        lines.append("")
        lines.append("-" * 70)
        lines.append(" TOPIC DIVERSITY")
        lines.append("-" * 70)
        lines.append(f" Unique topic words: {report.get('unique_topics', 0)}")
        lines.append(f" Top-3 concentration: {report.get('topic_concentration', 0):.1%}")
        top_topics = report.get("top_topics", [])
        if top_topics:
            lines.append(" Top topics:")
            for word, count in top_topics[:10]:
                lines.append(f" {word:<20s} {count}")
        # Issues (errors first, then warnings; each capped at 20 shown)
        issues = report.get("issues", [])
        error_issues = [i for i in issues if i.severity == "error"]
        warning_issues = [i for i in issues if i.severity == "warning"]
        if error_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" ERRORS ({len(error_issues)})")
            lines.append("-" * 70)
            for issue in error_issues[:20]:
                lines.append(f" {issue}")
            if len(error_issues) > 20:
                lines.append(f" ... and {len(error_issues) - 20} more errors")
        if warning_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" WARNINGS ({len(warning_issues)})")
            lines.append("-" * 70)
            for issue in warning_issues[:20]:
                lines.append(f" {issue}")
            if len(warning_issues) > 20:
                lines.append(f" ... and {len(warning_issues) - 20} more warnings")
        # Verdict
        lines.append("")
        lines.append("-" * 70)
        if (report["invalid"] == 0
                and report["exact_duplicates"] == 0
                and report.get("near_duplicates", 0) == 0
                and report.get("too_short", 0) == 0
                and report.get("empty_content", 0) == 0):
            lines.append(" VERDICT: PASS - Dataset is clean")
        elif report["invalid"] > report["total_lines"] * 0.1:
            lines.append(" VERDICT: FAIL - Too many invalid entries (>10%)")
        else:
            lines.append(" VERDICT: WARN - Some issues found, consider filtering")
        lines.append("-" * 70)
        lines.append("=" * 70)
        return "\n".join(lines)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
    """CLI entry point: validate a JSONL dataset, optionally save a JSON
    report and/or write a filtered copy."""
    parser = argparse.ArgumentParser(
        description="Codette Dataset Validator - check and clean JSONL training data"
    )
    parser.add_argument("dataset", help="Path to JSONL dataset file")
    parser.add_argument(
        "--filter", "-f", metavar="OUTPUT", default=None,
        help="Auto-filter and write clean dataset to OUTPUT path",
    )
    parser.add_argument(
        "--min-length", type=int, default=50,
        help="Minimum assistant response length in words (default: 50)",
    )
    parser.add_argument(
        "--max-length", type=int, default=10000,
        help="Maximum assistant response length in words (default: 10000)",
    )
    parser.add_argument(
        "--duplicate-threshold", type=float, default=0.85,
        help="Jaccard similarity threshold for near-duplicates (default: 0.85)",
    )
    parser.add_argument(
        "--json-report", metavar="PATH", default=None,
        help="Save report as JSON to this path",
    )
    args = parser.parse_args()

    validator = DatasetValidator(
        min_response_length=args.min_length,
        max_response_length=args.max_length,
        near_duplicate_threshold=args.duplicate_threshold,
    )
    print(f"Validating: {args.dataset}\n")
    report = validator.validate(args.dataset)
    print(validator.format_report(report))

    if args.json_report:
        # Drop items that json.dump cannot serialise (issue objects, entries).
        excluded = ("issues", "line_validity", "valid_entries")
        save_report = {k: v for k, v in report.items() if k not in excluded}
        save_report["issue_count"] = len(report["issues"])
        save_report["issues_summary"] = [repr(i) for i in report["issues"][:50]]
        os.makedirs(os.path.dirname(args.json_report) or ".", exist_ok=True)
        with open(args.json_report, "w", encoding="utf-8") as f:
            json.dump(save_report, f, indent=2, default=str)
        print(f"\nJSON report saved to: {args.json_report}")

    if args.filter:
        print(f"\nFiltering dataset -> {args.filter}")
        filter_stats = validator.filter_dataset(args.dataset, args.filter)
        print(f" Input lines: {filter_stats['input_lines']}")
        print(f" Kept: {filter_stats['kept']}")
        print(f" Removed: {filter_stats['removed']}")
        for reason, count in filter_stats["removal_reasons"].items():
            print(f" - {reason}: {count}")
# Run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()