| """ |
| Dataset Generator for Codette LoRA Training |
| ============================================= |
| |
| Main orchestrator that combines TemplateRegistry and AnswerGenerator |
| to produce chat-format JSONL files for fine-tuning Llama 3.1 8B |
| with LoRA adapters. |
| |
| Features: |
| - Deduplication: tracks all generated prompts to prevent duplicates |
| - Reproducible: seed-based RNG for deterministic output |
| - CLI interface: generate for one adapter or all adapters |
| - Progress reporting: logs generation progress |
| - Validation: checks output format before writing |
| |
| Usage: |
| python -m dataset_engine.dataset_generator --adapter newton --count 3000 |
| python -m dataset_engine.dataset_generator --all |
| python -m dataset_engine.dataset_generator --adapter philosophy --count 2000 --seed 42 |
| """ |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Optional, Set |
|
|
| from dataset_engine.template_registry import TemplateRegistry |
| from dataset_engine.answer_generator import AnswerGenerator |
|
|
| logger = logging.getLogger("dataset_generator") |
|
|
|
|
class DatasetGenerator:
    """Generates JSONL training datasets for Codette LoRA adapters.

    Questions are sampled from a ``TemplateRegistry``, answered by an
    ``AnswerGenerator``, de-duplicated, validated, and written out as one
    chat-format JSON object per line.
    """

    # Sampling gives up after this many attempts per requested example, so an
    # exhausted template pool cannot keep the loop spinning forever.
    _MAX_ATTEMPTS_PER_EXAMPLE = 5
    # A progress line is logged each time this many examples were accepted.
    _PROGRESS_INTERVAL = 500
    # Minimum total / distinct word counts for an answer to be kept.
    _MIN_ANSWER_WORDS = 40
    _MIN_UNIQUE_WORDS = 20

    def __init__(self, output_dir: str = "datasets", seed: Optional[int] = None):
        """Initialize the generator.

        Args:
            output_dir: Directory for output JSONL files. Created (including
                parents) if it does not exist yet.
            seed: Random seed for reproducibility. None for non-deterministic.
                The same seed is handed to both the template registry and the
                answer generator.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.seed = seed
        self.registry = TemplateRegistry(seed=seed)
        self.answer_gen = AnswerGenerator(seed=seed)
        self._seen_questions: Set[str] = set()
        self._stats = self._fresh_stats()

    @staticmethod
    def _fresh_stats() -> dict:
        """Return a zeroed statistics dict (single source for its keys)."""
        return {
            "total_generated": 0,
            "duplicates_skipped": 0,
            "counterexamples": 0,
        }

    def reset_dedup(self):
        """Clear the deduplication set (use between adapters)."""
        self._seen_questions.clear()

    def reset_stats(self):
        """Reset generation statistics."""
        self._stats = self._fresh_stats()

    def generate_adapter(self, adapter: str,
                         count: Optional[int] = None) -> str:
        """Generate a JSONL dataset for a single adapter.

        The deduplication set and the statistics are reset at the start of
        every call, so results of a previous adapter never leak into this one.

        Args:
            adapter: Adapter name (e.g. 'newton', 'philosophy').
            count: Number of examples to generate. ``None`` means "use the
                adapter's target size from the registry"; an explicit ``0``
                produces an empty file.

        Returns:
            Path to the generated JSONL file
            (``<output_dir>/<adapter>_reasoning.jsonl``).

        Raises:
            ValueError: If the adapter is unknown to the registry or if
                ``count`` is negative.
        """
        adapter_names = self.registry.get_adapter_names()
        if adapter not in adapter_names:
            raise ValueError(
                f"Unknown adapter '{adapter}'. "
                f"Available: {adapter_names}"
            )
        if count is not None and count < 0:
            raise ValueError(f"count must be non-negative, got {count}")

        # Compare against None instead of relying on truthiness: `count or
        # default` would silently swap an explicit 0 for the default target.
        target = self.registry.get_target(adapter) if count is None else count
        output_path = self.output_dir / f"{adapter}_reasoning.jsonl"

        self.reset_dedup()
        self.reset_stats()

        logger.info(
            "Generating %d examples for adapter '%s' -> %s",
            target, adapter, output_path,
        )

        # perf_counter is monotonic, so elapsed times cannot go negative when
        # the wall clock is adjusted mid-run.
        start_time = time.perf_counter()
        examples = self._collect_examples(adapter, target, start_time)
        self._write_jsonl(output_path, examples)

        elapsed = time.perf_counter() - start_time
        counter_pct = (
            (self._stats["counterexamples"] / len(examples) * 100)
            if examples else 0
        )

        logger.info(
            "Completed '%s': %d examples in %.1fs "
            "(%.1f%% counterexamples, %d duplicates skipped)",
            adapter, len(examples), elapsed, counter_pct,
            self._stats["duplicates_skipped"],
        )

        if len(examples) < target:
            logger.warning(
                "Only generated %d / %d examples for '%s'. "
                "Consider expanding template pools.",
                len(examples), target, adapter,
            )

        return str(output_path)

    def _collect_examples(self, adapter: str, target: int,
                          start_time: float) -> list:
        """Sample, de-duplicate and validate up to ``target`` chat examples.

        Stops early once ``target`` examples were accepted, and gives up after
        ``target * _MAX_ATTEMPTS_PER_EXAMPLE`` sampling attempts.
        """
        examples = []
        for _ in range(target * self._MAX_ATTEMPTS_PER_EXAMPLE):
            if len(examples) >= target:
                break

            question, topic, subtopic, qtype = self.registry.sample_question(adapter)

            # Duplicates are detected case- and surrounding-whitespace-insensitively.
            q_normalized = question.strip().lower()
            if q_normalized in self._seen_questions:
                self._stats["duplicates_skipped"] += 1
                continue
            # Recorded before validation: a question whose answer is rejected
            # is not retried later in the same run.
            self._seen_questions.add(q_normalized)

            answer = self.answer_gen.generate(
                adapter=adapter,
                topic=topic,
                subtopic=subtopic,
                question=question,
                question_type=qtype,
            )
            if not self._validate_answer(answer):
                continue

            examples.append(self._build_example(question, answer))
            if qtype == "counterexample":
                self._stats["counterexamples"] += 1
            self._stats["total_generated"] = len(examples)

            if len(examples) % self._PROGRESS_INTERVAL == 0:
                elapsed = time.perf_counter() - start_time
                rate = len(examples) / elapsed if elapsed > 0 else 0
                logger.info(
                    " [%s] %d / %d examples (%.1f/sec, %d duplicates skipped)",
                    adapter, len(examples), target, rate,
                    self._stats["duplicates_skipped"],
                )
        return examples

    def _build_example(self, question: str, answer: str) -> dict:
        """Wrap a question/answer pair in the system/user/assistant chat record."""
        return {
            "messages": [
                {"role": "system", "content": self.registry.SYSTEM_PROMPT},
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
        }

    @staticmethod
    def _write_jsonl(output_path, examples) -> None:
        """Write one JSON object per line (UTF-8, non-ASCII kept verbatim)."""
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(example, ensure_ascii=False) + "\n"
                for example in examples
            )

    def generate_all(self) -> dict:
        """Generate datasets for all adapters.

        A failure in one adapter is logged (with traceback) and recorded; the
        remaining adapters are still generated.

        Returns:
            Dict mapping adapter names to output file paths. For an adapter
            that failed, the value is the string ``"ERROR: <message>"``.
        """
        results = {}
        # Failures are tracked by adapter name rather than by sniffing the
        # "ERROR" prefix, which would misclassify an output path that happens
        # to start with "ERROR".
        failed = set()
        total_start = time.perf_counter()

        for adapter in self.registry.get_adapter_names():
            try:
                results[adapter] = self.generate_adapter(adapter)
            except Exception as e:
                # Batch boundary: keep going with the other adapters.
                logger.exception("Failed to generate '%s': %s", adapter, e)
                results[adapter] = f"ERROR: {e}"
                failed.add(adapter)

        total_elapsed = time.perf_counter() - total_start
        total_examples = sum(
            self._count_lines(path)
            for adapter, path in results.items()
            if adapter not in failed
        )
        logger.info(
            "All adapters complete: %d total examples in %.1fs",
            total_examples, total_elapsed,
        )
        return results

    @staticmethod
    def _validate_answer(answer: str) -> bool:
        """Check that an answer meets minimum quality standards.

        An answer is accepted when it is a string with at least
        ``_MIN_ANSWER_WORDS`` whitespace-separated words, of which at least
        ``_MIN_UNIQUE_WORDS`` are distinct (case-insensitively).
        """
        if not isinstance(answer, str):
            return False
        words = answer.split()
        if len(words) < DatasetGenerator._MIN_ANSWER_WORDS:
            return False
        # Guards against answers that just repeat a handful of words.
        unique_words = {w.lower() for w in words}
        return len(unique_words) >= DatasetGenerator._MIN_UNIQUE_WORDS

    @staticmethod
    def _count_lines(filepath: str) -> int:
        """Count lines in a file; returns 0 if the file cannot be read."""
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                return sum(1 for _ in f)
        except OSError:
            return 0
|
|
|
|
def main():
    """CLI entry point.

    Parses the command line, configures logging, and generates either one
    adapter's dataset (``--adapter``) or every adapter's dataset (``--all``),
    then prints a short summary to stdout.

    Usage errors (neither ``--adapter`` nor ``--all`` given, negative
    ``--count``, unknown adapter name) are reported through
    ``parser.error``, which prints the usage text and exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Generate JSONL training datasets for Codette LoRA adapters.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            " python -m dataset_engine.dataset_generator --adapter newton --count 3000\n"
            " python -m dataset_engine.dataset_generator --all\n"
            " python -m dataset_engine.dataset_generator --all --seed 42\n"
            " python -m dataset_engine.dataset_generator --adapter philosophy --output-dir ./my_datasets\n"
        ),
    )

    parser.add_argument(
        "--adapter",
        type=str,
        help="Adapter name to generate for (e.g. newton, philosophy).",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Generate datasets for ALL adapters with their target sizes.",
    )
    parser.add_argument(
        "--count",
        type=int,
        default=None,
        help="Number of examples to generate (overrides default target).",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="datasets",
        help="Output directory for JSONL files (default: datasets).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible generation.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging.",
    )

    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Reject bad invocations before anything touches the filesystem.
    if not args.adapter and not args.all:
        parser.error("Specify --adapter NAME or --all")
    if args.count is not None and args.count < 0:
        parser.error(f"--count must be non-negative, got {args.count}")

    if args.all:
        # --all always uses each adapter's own target size, so say so rather
        # than silently dropping the flags the user typed.
        ignored = []
        if args.adapter:
            ignored.append("--adapter")
        if args.count is not None:
            ignored.append("--count")
        if ignored:
            logging.getLogger("dataset_generator").warning(
                "%s ignored because --all was given", ", ".join(ignored)
            )

    generator = DatasetGenerator(
        output_dir=args.output_dir,
        seed=args.seed,
    )

    if args.all:
        results = generator.generate_all()
        print("\n--- Generation Summary ---")
        for adapter, path in results.items():
            # generate_all() stores failures as "ERROR: <message>".
            if path.startswith("ERROR: "):
                print(f" {adapter}: {path}")
            else:
                count = generator._count_lines(path)
                print(f" {adapter}: {count} examples -> {path}")
        return

    # A mistyped adapter name is a usage error, not a crash with traceback.
    known_adapters = generator.registry.get_adapter_names()
    if args.adapter not in known_adapters:
        parser.error(
            f"Unknown adapter '{args.adapter}'. Available: {known_adapters}"
        )

    path = generator.generate_adapter(args.adapter, args.count)
    count = generator._count_lines(path)
    print(f"\nGenerated {count} examples -> {path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|