| """ |
| Dataset Validator - checks JSONL training dataset quality. |
| |
| Validates format, structure, duplicates, length, diversity, |
| and can auto-filter to produce a clean dataset. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import hashlib |
| import json |
| import os |
| import re |
| import sys |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Set, Tuple |
|
|
# Make the project root importable when this file is run directly as a
# script (rather than imported as part of the package).
_THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = _THIS_DIR.parent
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))
|
|
|
|
| |
| |
| |
|
|
| def _text_hash(text: str) -> str: |
| """SHA-256 of normalised text for exact duplicate detection.""" |
| normalised = re.sub(r"\s+", " ", text.strip().lower()) |
| return hashlib.sha256(normalised.encode("utf-8")).hexdigest() |
|
|
|
|
| def _word_set(text: str) -> Set[str]: |
| """Set of lowercase words for Jaccard similarity.""" |
| return set(re.findall(r"[a-z]{2,}", text.lower())) |
|
|
|
|
| def _jaccard_similarity(a: Set[str], b: Set[str]) -> float: |
| if not a and not b: |
| return 1.0 |
| union = a | b |
| if not union: |
| return 0.0 |
| return len(a & b) / len(union) |
|
|
|
|
| def _extract_topic_words(text: str, top_n: int = 5) -> List[str]: |
| """Extract dominant topic words from text.""" |
| stop = { |
| "the", "a", "an", "is", "are", "was", "were", "be", "been", |
| "have", "has", "had", "do", "does", "did", "will", "would", |
| "to", "of", "in", "for", "on", "with", "at", "by", "from", |
| "as", "and", "but", "or", "if", "that", "this", "what", |
| "which", "it", "its", "they", "them", "their", "not", "you", |
| "your", "can", "could", "should", "may", "might", "must", |
| "how", "why", "when", "where", "who", "whom", "about", |
| } |
| words = re.findall(r"[a-z]{3,}", text.lower()) |
| filtered = [w for w in words if w not in stop] |
| counts = Counter(filtered) |
| return [w for w, _ in counts.most_common(top_n)] |
|
|
|
|
| |
| |
| |
|
|
class ValidationIssue:
    """One validation problem found in the dataset.

    Attributes:
        line_num: 1-based line number the issue refers to (0 is used for
            file-level issues such as low topic diversity).
        severity: "error", "warning", or "info".
        code: short machine-readable code, e.g. "PARSE_ERROR".
        message: human-readable description of the problem.
    """

    def __init__(self, line_num: int, severity: str, code: str, message: str):
        self.line_num = line_num
        self.severity = severity
        self.code = code
        self.message = message

    def __repr__(self) -> str:
        # This exact rendering is what the text report prints per issue.
        severity_tag = self.severity.upper()
        return f"[{severity_tag}] Line {self.line_num}: {self.code} - {self.message}"
|
|
|
|
| |
| |
| |
|
|
class DatasetValidator:
    """Validate and clean JSONL training datasets."""

    # Roles every conversation entry must contain at least once.
    REQUIRED_ROLES = {"system", "user", "assistant"}

    def __init__(
        self,
        min_response_length: int = 50,
        max_response_length: int = 10000,
        near_duplicate_threshold: float = 0.85,
    ):
        # Length bounds are measured in whole words of the assistant response.
        self.min_response_length = min_response_length
        self.max_response_length = max_response_length
        # Jaccard similarity (over user-message word sets) at or above this
        # value flags an entry as a near-duplicate.
        self.near_duplicate_threshold = near_duplicate_threshold
|
|
    def validate(self, filepath: str) -> Dict[str, Any]:
        """Validate a JSONL dataset file.

        Checks JSON parse-ability, message structure and roles, assistant
        response length bounds, exact and near duplicates, and topic
        diversity across the whole file.

        Returns a comprehensive report dict with:
        - statistics (total, valid, invalid, duplicate, etc.)
        - issues list
        - per-line validity

        Raises:
            FileNotFoundError: if *filepath* does not exist.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Dataset file not found: {filepath}")

        issues: List[ValidationIssue] = []
        # All successfully parsed entries; collected but not returned.
        entries: List[Dict[str, Any]] = []
        valid_entries: List[Dict[str, Any]] = []
        line_validity: List[bool] = []

        # Normalised user+assistant text hash -> first line number seen with it.
        exact_hashes: Dict[str, int] = {}
        # (line number, user-message word set) pairs for near-duplicate checks.
        # NOTE(review): each entry is compared against all earlier ones, so
        # this scan is O(n^2) in the number of entries.
        near_dup_sets: List[Tuple[int, Set[str]]] = []

        stats = {
            "total_lines": 0,
            "valid": 0,
            "invalid": 0,
            "parse_errors": 0,
            "missing_roles": 0,
            "exact_duplicates": 0,
            "near_duplicates": 0,
            "too_short": 0,
            "too_long": 0,
            "empty_content": 0,
            "response_lengths": [],  # words per assistant response
            "topic_words": [],       # aggregated topic words across entries
        }

        with open(filepath, "r", encoding="utf-8") as f:
            for line_num, raw_line in enumerate(f, start=1):
                stats["total_lines"] += 1
                raw_line = raw_line.strip()

                # Blank line: flagged and counted invalid, never parsed.
                if not raw_line:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_LINE", "Empty line"
                    ))
                    line_validity.append(False)
                    stats["invalid"] += 1
                    continue

                # JSON parse check.
                try:
                    entry = json.loads(raw_line)
                except json.JSONDecodeError as e:
                    issues.append(ValidationIssue(
                        line_num, "error", "PARSE_ERROR",
                        f"Invalid JSON: {e}"
                    ))
                    line_validity.append(False)
                    stats["parse_errors"] += 1
                    stats["invalid"] += 1
                    continue

                entries.append(entry)
                entry_valid = True

                # Structural check: 'messages' must be a list.
                messages = entry.get("messages")
                if not isinstance(messages, list):
                    issues.append(ValidationIssue(
                        line_num, "error", "NO_MESSAGES",
                        "Missing or invalid 'messages' field"
                    ))
                    entry_valid = False
                    stats["invalid"] += 1
                    line_validity.append(False)
                    continue

                # Collect roles and user/assistant text for the checks below.
                roles_present = set()
                assistant_content = ""
                user_content = ""
                has_empty = False

                for msg in messages:
                    role = msg.get("role", "")
                    content = msg.get("content", "")
                    roles_present.add(role)

                    if role == "assistant":
                        assistant_content = content or ""
                    elif role == "user":
                        user_content = content or ""

                    if not content or not content.strip():
                        has_empty = True

                # Missing required roles invalidates the entry.
                missing_roles = self.REQUIRED_ROLES - roles_present
                if missing_roles:
                    issues.append(ValidationIssue(
                        line_num, "error", "MISSING_ROLES",
                        f"Missing roles: {missing_roles}"
                    ))
                    entry_valid = False
                    stats["missing_roles"] += 1

                # Empty content is a warning only; it does not invalidate.
                if has_empty:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EMPTY_CONTENT",
                        "One or more messages have empty content"
                    ))
                    stats["empty_content"] += 1

                # Length checks are warnings; they do not invalidate either.
                resp_len = len(assistant_content.split())
                stats["response_lengths"].append(resp_len)

                if resp_len < self.min_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_SHORT",
                        f"Assistant response too short: {resp_len} words "
                        f"(min: {self.min_response_length})"
                    ))
                    stats["too_short"] += 1

                if resp_len > self.max_response_length:
                    issues.append(ValidationIssue(
                        line_num, "warning", "TOO_LONG",
                        f"Assistant response too long: {resp_len} words "
                        f"(max: {self.max_response_length})"
                    ))
                    stats["too_long"] += 1

                # Exact duplicate: identical normalised user+assistant text.
                # Duplicates are excluded from valid_entries.
                combined_text = user_content + " " + assistant_content
                h = _text_hash(combined_text)
                if h in exact_hashes:
                    issues.append(ValidationIssue(
                        line_num, "warning", "EXACT_DUPLICATE",
                        f"Exact duplicate of line {exact_hashes[h]}"
                    ))
                    stats["exact_duplicates"] += 1
                    entry_valid = False
                else:
                    exact_hashes[h] = line_num

                # Near duplicate: high Jaccard overlap with an earlier user
                # message. Informational only; does not invalidate.
                if user_content:
                    user_words = _word_set(user_content)
                    for prev_line, prev_words in near_dup_sets:
                        sim = _jaccard_similarity(user_words, prev_words)
                        if sim >= self.near_duplicate_threshold:
                            issues.append(ValidationIssue(
                                line_num, "info", "NEAR_DUPLICATE",
                                f"Near-duplicate of line {prev_line} "
                                f"(Jaccard: {sim:.3f})"
                            ))
                            stats["near_duplicates"] += 1
                            break
                    near_dup_sets.append((line_num, user_words))

                # Topic words feed the diversity analysis after the loop.
                topic_words = _extract_topic_words(user_content + " " + assistant_content)
                stats["topic_words"].extend(topic_words)

                if entry_valid:
                    stats["valid"] += 1
                    valid_entries.append(entry)
                    line_validity.append(True)
                else:
                    stats["invalid"] += 1
                    line_validity.append(False)

        # Topic diversity: how concentrated the dataset is on few topics.
        topic_counts = Counter(stats["topic_words"])
        total_topics = len(set(stats["topic_words"]))
        top_topics = topic_counts.most_common(20)

        # Share of all topic-word occurrences taken by the top three words.
        if topic_counts:
            top3_count = sum(c for _, c in topic_counts.most_common(3))
            total_count = sum(topic_counts.values())
            concentration = top3_count / total_count if total_count else 0
        else:
            concentration = 0

        if concentration > 0.5:
            top_kw = ", ".join(w for w, _ in topic_counts.most_common(3))
            # line_num 0 marks a file-level (not per-line) issue.
            issues.append(ValidationIssue(
                0, "warning", "LOW_DIVERSITY",
                f"Dataset is concentrated on few topics ({concentration:.0%} "
                f"in top-3: {top_kw}). Consider adding more diverse examples."
            ))

        # Response length distribution (simple index-based percentiles).
        lengths = stats["response_lengths"]
        length_stats = {}
        if lengths:
            lengths_sorted = sorted(lengths)
            length_stats = {
                "min": lengths_sorted[0],
                "max": lengths_sorted[-1],
                "mean": round(sum(lengths) / len(lengths), 1),
                "median": lengths_sorted[len(lengths) // 2],
                "p10": lengths_sorted[int(len(lengths) * 0.1)],
                "p90": lengths_sorted[int(len(lengths) * 0.9)],
            }

        report = {
            "filepath": str(filepath),
            "total_lines": stats["total_lines"],
            "valid": stats["valid"],
            "invalid": stats["invalid"],
            "parse_errors": stats["parse_errors"],
            "missing_roles": stats["missing_roles"],
            "exact_duplicates": stats["exact_duplicates"],
            "near_duplicates": stats["near_duplicates"],
            "too_short": stats["too_short"],
            "too_long": stats["too_long"],
            "empty_content": stats["empty_content"],
            "unique_topics": total_topics,
            "topic_concentration": round(concentration, 4),
            "top_topics": top_topics,
            "response_length_stats": length_stats,
            "issues": issues,
            "line_validity": line_validity,
            "valid_entries": valid_entries,
        }

        return report
|
|
| |
|
|
| def filter_dataset( |
| self, |
| filepath: str, |
| output_path: str, |
| remove_duplicates: bool = True, |
| remove_short: bool = True, |
| remove_long: bool = True, |
| remove_invalid: bool = True, |
| ) -> Dict[str, int]: |
| """Validate and write a cleaned dataset. |
| |
| Returns stats about the filtering. |
| """ |
| report = self.validate(filepath) |
| issues_by_line: Dict[int, List[ValidationIssue]] = defaultdict(list) |
| for issue in report["issues"]: |
| issues_by_line[issue.line_num].append(issue) |
|
|
| kept = 0 |
| removed = 0 |
| reasons: Dict[str, int] = defaultdict(int) |
|
|
| os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) |
|
|
| with open(filepath, "r", encoding="utf-8") as fin, \ |
| open(output_path, "w", encoding="utf-8") as fout: |
|
|
| seen_hashes: Set[str] = set() |
|
|
| for line_num, raw_line in enumerate(fin, start=1): |
| raw_line = raw_line.strip() |
| if not raw_line: |
| removed += 1 |
| reasons["empty_line"] += 1 |
| continue |
|
|
| try: |
| entry = json.loads(raw_line) |
| except json.JSONDecodeError: |
| if remove_invalid: |
| removed += 1 |
| reasons["parse_error"] += 1 |
| continue |
|
|
| messages = entry.get("messages", []) |
| if not isinstance(messages, list): |
| if remove_invalid: |
| removed += 1 |
| reasons["no_messages"] += 1 |
| continue |
|
|
| roles = {m.get("role") for m in messages} |
| if self.REQUIRED_ROLES - roles: |
| if remove_invalid: |
| removed += 1 |
| reasons["missing_roles"] += 1 |
| continue |
|
|
| |
| assistant_text = "" |
| user_text = "" |
| for m in messages: |
| if m.get("role") == "assistant": |
| assistant_text = m.get("content", "") |
| elif m.get("role") == "user": |
| user_text = m.get("content", "") |
|
|
| |
| word_count = len(assistant_text.split()) |
| if remove_short and word_count < self.min_response_length: |
| removed += 1 |
| reasons["too_short"] += 1 |
| continue |
| if remove_long and word_count > self.max_response_length: |
| removed += 1 |
| reasons["too_long"] += 1 |
| continue |
|
|
| |
| if remove_duplicates: |
| h = _text_hash(user_text + " " + assistant_text) |
| if h in seen_hashes: |
| removed += 1 |
| reasons["duplicate"] += 1 |
| continue |
| seen_hashes.add(h) |
|
|
| fout.write(json.dumps(entry, ensure_ascii=False) + "\n") |
| kept += 1 |
|
|
| return { |
| "input_lines": report["total_lines"], |
| "kept": kept, |
| "removed": removed, |
| "removal_reasons": dict(reasons), |
| } |
|
|
| |
|
|
    def format_report(self, report: Dict[str, Any]) -> str:
        """Format validation report as readable text.

        Args:
            report: the dict produced by ``validate``.

        Returns:
            A multi-line, human-readable summary ending with a
            PASS / WARN / FAIL verdict.
        """
        lines: List[str] = []
        lines.append("=" * 70)
        lines.append(" DATASET VALIDATION REPORT")
        lines.append("=" * 70)
        lines.append(f" File: {report['filepath']}")
        lines.append("")

        # Headline counters.
        lines.append("-" * 70)
        lines.append(" SUMMARY")
        lines.append("-" * 70)
        lines.append(f" Total lines: {report['total_lines']}")
        lines.append(f" Valid: {report['valid']}")
        lines.append(f" Invalid: {report['invalid']}")
        lines.append(f" Parse errors: {report['parse_errors']}")
        lines.append(f" Missing roles: {report['missing_roles']}")
        lines.append(f" Exact duplicates: {report['exact_duplicates']}")
        lines.append(f" Near duplicates: {report['near_duplicates']}")
        lines.append(f" Too short: {report['too_short']}")
        lines.append(f" Too long: {report['too_long']}")
        lines.append(f" Empty content: {report['empty_content']}")

        # Response length distribution (empty dict when no entry parsed).
        ls = report.get("response_length_stats", {})
        if ls:
            lines.append("")
            lines.append("-" * 70)
            lines.append(" RESPONSE LENGTH (words)")
            lines.append("-" * 70)
            lines.append(f" Min: {ls.get('min', 'N/A')}")
            lines.append(f" Max: {ls.get('max', 'N/A')}")
            lines.append(f" Mean: {ls.get('mean', 'N/A')}")
            lines.append(f" Median: {ls.get('median', 'N/A')}")
            lines.append(f" P10: {ls.get('p10', 'N/A')}")
            lines.append(f" P90: {ls.get('p90', 'N/A')}")

        # Topic diversity summary.
        lines.append("")
        lines.append("-" * 70)
        lines.append(" TOPIC DIVERSITY")
        lines.append("-" * 70)
        lines.append(f" Unique topic words: {report.get('unique_topics', 0)}")
        lines.append(f" Top-3 concentration: {report.get('topic_concentration', 0):.1%}")
        top_topics = report.get("top_topics", [])
        if top_topics:
            lines.append(" Top topics:")
            for word, count in top_topics[:10]:
                lines.append(f" {word:<20s} {count}")

        # Issues, capped at 20 per severity to keep the report readable
        # (info-level issues are intentionally not listed here).
        issues = report.get("issues", [])
        error_issues = [i for i in issues if i.severity == "error"]
        warning_issues = [i for i in issues if i.severity == "warning"]

        if error_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" ERRORS ({len(error_issues)})")
            lines.append("-" * 70)
            for issue in error_issues[:20]:
                lines.append(f" {issue}")
            if len(error_issues) > 20:
                lines.append(f" ... and {len(error_issues) - 20} more errors")

        if warning_issues:
            lines.append("")
            lines.append("-" * 70)
            lines.append(f" WARNINGS ({len(warning_issues)})")
            lines.append("-" * 70)
            for issue in warning_issues[:20]:
                lines.append(f" {issue}")
            if len(warning_issues) > 20:
                lines.append(f" ... and {len(warning_issues) - 20} more warnings")

        # Overall verdict: PASS only when fully clean; FAIL above 10%
        # invalid entries; WARN otherwise.
        lines.append("")
        lines.append("-" * 70)
        if (report["invalid"] == 0
                and report["exact_duplicates"] == 0
                and report.get("near_duplicates", 0) == 0
                and report.get("too_short", 0) == 0
                and report.get("empty_content", 0) == 0):
            lines.append(" VERDICT: PASS - Dataset is clean")
        elif report["invalid"] > report["total_lines"] * 0.1:
            lines.append(" VERDICT: FAIL - Too many invalid entries (>10%)")
        else:
            lines.append(" VERDICT: WARN - Some issues found, consider filtering")
        lines.append("-" * 70)

        lines.append("=" * 70)
        return "\n".join(lines)
|
|
|
|
| |
| |
| |
|
|
def main() -> None:
    """CLI entry point: validate a JSONL dataset, optionally save a JSON
    report and/or write a filtered copy."""
    arg_parser = argparse.ArgumentParser(
        description="Codette Dataset Validator - check and clean JSONL training data"
    )
    arg_parser.add_argument(
        "dataset",
        help="Path to JSONL dataset file",
    )
    arg_parser.add_argument(
        "--filter", "-f",
        metavar="OUTPUT",
        default=None,
        help="Auto-filter and write clean dataset to OUTPUT path",
    )
    arg_parser.add_argument(
        "--min-length",
        type=int,
        default=50,
        help="Minimum assistant response length in words (default: 50)",
    )
    arg_parser.add_argument(
        "--max-length",
        type=int,
        default=10000,
        help="Maximum assistant response length in words (default: 10000)",
    )
    arg_parser.add_argument(
        "--duplicate-threshold",
        type=float,
        default=0.85,
        help="Jaccard similarity threshold for near-duplicates (default: 0.85)",
    )
    arg_parser.add_argument(
        "--json-report",
        metavar="PATH",
        default=None,
        help="Save report as JSON to this path",
    )
    opts = arg_parser.parse_args()

    checker = DatasetValidator(
        min_response_length=opts.min_length,
        max_response_length=opts.max_length,
        near_duplicate_threshold=opts.duplicate_threshold,
    )

    print(f"Validating: {opts.dataset}\n")
    report = checker.validate(opts.dataset)
    print(checker.format_report(report))

    if opts.json_report:
        # Drop the heavyweight / non-JSON-friendly fields before saving;
        # keep a count and a capped textual summary of the issues instead.
        excluded = ("issues", "line_validity", "valid_entries")
        save_report = {k: v for k, v in report.items() if k not in excluded}
        save_report["issue_count"] = len(report["issues"])
        save_report["issues_summary"] = [repr(i) for i in report["issues"][:50]]
        os.makedirs(os.path.dirname(opts.json_report) or ".", exist_ok=True)
        with open(opts.json_report, "w", encoding="utf-8") as f:
            json.dump(save_report, f, indent=2, default=str)
        print(f"\nJSON report saved to: {opts.json_report}")

    if opts.filter:
        print(f"\nFiltering dataset -> {opts.filter}")
        filter_stats = checker.filter_dataset(opts.dataset, opts.filter)
        print(f" Input lines: {filter_stats['input_lines']}")
        print(f" Kept: {filter_stats['kept']}")
        print(f" Removed: {filter_stats['removed']}")
        for reason, count in filter_stats["removal_reasons"].items():
            print(f" - {reason}: {count}")
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|