| | """ |
| | Data Preprocessing for Memory Routing Training |
| | |
| | This script converts synthetic JSONL conversations to Tinker-compatible |
| | types.Datum objects for supervised fine-tuning. |
| | |
| | Per Tinker docs (rendering.mdx): |
| | - Use renderer.build_supervised_example() to get tokens and weights |
| | - Weights indicate which tokens to train on (1.0 for completion, 0.0 for prompt) |
| | - Target tokens are shifted by 1 (predicting next token) |
| | |
| | Per PRD Section 6.6: |
| | - Validate datum length <= 4096 |
| | - Ensure non-zero weights |
| | - Verify token IDs are within vocab range |
| | """ |
| |
|
| | import json |
| | import os |
| | from typing import List, Dict, Any, Tuple |
| | from dataclasses import dataclass |
| |
|
| | |
| | |
| | |
| |
|
# Base model and the chat renderer that matches its template (Tinker path only).
MODEL_NAME = "meta-llama/Llama-3.1-8B"
RENDERER_NAME = "llama3"
# Upper bound on rendered sequence length (PRD Section 6.6).
MAX_SEQUENCE_LENGTH = 4_096

# Closed taxonomy of memory-routing labels. Examples carrying any label
# outside this set are skipped during preprocessing.
VALID_CATEGORIES = {
    # Company-scoped memory
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    # User-scoped memory
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    # Irrelevant, vague, or transactional content
    "none",
}
| |
|
@dataclass
class PreprocessingStats:
    """Running tallies gathered while converting raw examples into training records.

    ``total_examples`` is filled in once from the size of the loaded dataset;
    the remaining counters are incremented by the per-example preprocessing
    functions as each example is accepted or rejected.
    """

    # Number of raw examples read from the input file.
    total_examples: int = 0
    # Examples that passed every check and were emitted.
    valid_examples: int = 0
    # Rejected: rendered (or estimated) length exceeded MAX_SEQUENCE_LENGTH.
    skipped_too_long: int = 0
    # Rejected: no token carried a non-zero loss weight.
    skipped_zero_weights: int = 0
    # Rejected: a token id failed the vocabulary-range check.
    skipped_invalid_tokens: int = 0
    # Rejected: category labels failed validation against the taxonomy.
    skipped_invalid_categories: int = 0
| |
|
| |
|
def build_routing_prompt(conversation: List[Dict[str, str]], categories: List[str]) -> List[Dict[str, str]]:
    """
    Build the three-message chat used as one supervised training example.

    The returned list contains, in order:
      1. a system message describing the memory-category taxonomy,
      2. a user message embedding the flattened conversation transcript,
      3. an assistant message whose content is the comma-separated categories.

    Per PRD Section 6 - Student Prompt format.

    Args:
        conversation: Sequence of turns. Dict turns become ``ROLE: content``
            lines (a missing/null role is rendered as ``UNKNOWN`` and a null
            content as an empty string); bare strings become
            ``UNKNOWN: <text>``; any other value is skipped. ``None`` is
            treated as an empty conversation.
        categories: Category labels, joined with ", " to form the target.

    Returns:
        List of ``{"role": ..., "content": ...}`` message dicts.
    """
    system_content = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""

    lines = []
    for turn in conversation or ():
        if isinstance(turn, str):
            # Plain-string turns carry no role information.
            lines.append(f"UNKNOWN: {turn}")
            continue
        if not isinstance(turn, dict):
            continue
        # A null role would crash on .upper(), and a null content would leak
        # the literal text "None" into the training prompt.
        role = turn.get("role") or "unknown"
        content = turn.get("content")
        if content is None:
            content = ""
        lines.append(f"{str(role).upper()}: {content}")

    # Joining once avoids quadratic string concatenation on long transcripts.
    conversation_text = "\n".join(lines).strip()
    user_content = f"Conversation:\n{conversation_text}\n\nWhat memory categories apply?"

    assistant_content = ", ".join(categories)

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
| |
|
| |
|
def load_synthetic_data(filepath: str) -> List[Dict[str, Any]]:
    """
    Load synthetic examples from a JSONL file (one JSON value per line).

    Blank or whitespace-only lines are skipped. The file is decoded as UTF-8
    explicitly, so non-ASCII conversation text loads the same way on every
    platform instead of depending on the locale's default encoding.

    Args:
        filepath: Path to the JSONL file.

    Returns:
        The decoded records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
| |
|
| |
|
| | def validate_categories(categories: List[str]) -> bool: |
| | """Validate that all categories are in the taxonomy.""" |
| | return all(cat in VALID_CATEGORIES for cat in categories) |
| |
|
| |
|
def preprocess_example_mock(example: Dict[str, Any], stats: PreprocessingStats) -> Dict[str, Any] | None:
    """
    Validate and render one example without Tinker installed.

    Applies the checks that do not need a tokenizer:
    - the category labels must pass ``validate_categories``;
    - the rendered messages must fit in ``MAX_SEQUENCE_LENGTH``, estimated at
      about 4 characters per token.

    Use this for testing without Tinker installed.

    Args:
        example: Raw record. Reads ``conversation``, ``labels["categories"]``
            and ``scenario_id``; a null or non-dict ``labels`` is treated as
            having no categories.
        stats: Counters updated in place.

    Returns:
        A dict with ``messages``, ``categories``, ``estimated_tokens`` and
        ``scenario_id``, or None if the example was rejected (the matching
        ``stats.skipped_*`` counter is incremented).
    """
    # Guard against explicit nulls, which .get(key, default) does not replace.
    labels = example.get("labels")
    if not isinstance(labels, dict):
        labels = {}
    conversation = example.get("conversation") or []
    categories = labels.get("categories", [])

    if not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    # Rough heuristic: ~4 characters per token; chat-template tokens are not counted.
    total_chars = sum(len(m["content"]) for m in training_messages)
    estimated_tokens = total_chars // 4

    if estimated_tokens > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    stats.valid_examples += 1

    return {
        "messages": training_messages,
        "categories": categories,
        "estimated_tokens": estimated_tokens,
        "scenario_id": example.get("scenario_id", "unknown"),
    }
| |
|
| |
|
def preprocess_with_tinker(example: Dict[str, Any], renderer, tokenizer, vocab_size: int, stats: PreprocessingStats):
    """
    Convert one raw example into a Tinker ``types.Datum`` for supervised training.

    Per Tinker docs (rendering.mdx):
    - build_supervised_example returns (tokens, weights)
    - weights=1.0 for completion tokens, weights=0.0 for prompt tokens

    Per Tinker docs (training-sampling.mdx):
    - input_tokens = tokens[:-1]
    - target_tokens = tokens[1:]  # Shifted for next-token prediction
    - weights = weights[1:]

    Args:
        example: Raw record. Reads ``conversation`` and ``labels["categories"]``;
            a null or non-dict ``labels`` is treated as having no categories.
        renderer: Chat renderer providing ``build_supervised_example``.
        tokenizer: Not used in this function; kept for call-site compatibility.
        vocab_size: Every token id must lie in ``[0, vocab_size)``.
        stats: Counters updated in place.

    Returns:
        The Datum, or None if the example was rejected (the matching
        ``stats.skipped_*`` counter is incremented).
    """
    from tinker import types

    # Guard against explicit nulls, which .get(key, default) does not replace.
    labels = example.get("labels")
    if not isinstance(labels, dict):
        labels = {}
    conversation = example.get("conversation") or []
    categories = labels.get("categories", [])

    if not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    # NOTE(review): assumes `tokens`/`weights` are flat sequences (list or 1-D
    # tensor) as in the rendering docs; confirm against the installed
    # tinker_cookbook version.
    tokens, weights = renderer.build_supervised_example(training_messages)

    # Conservative: checks the full rendered length; the datum itself is one
    # token shorter after the shift below.
    if len(tokens) > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    loss_weights = weights[1:]

    if sum(loss_weights) == 0:
        stats.skipped_zero_weights += 1
        return None

    # Check the whole sequence: tokens[0] is fed to the model as input but
    # never appears among the targets, so a targets-only check would miss it.
    if not all(0 <= t < vocab_size for t in tokens):
        stats.skipped_invalid_tokens += 1
        return None

    datum = types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs=dict(
            target_tokens=target_tokens,
            weights=loss_weights,
        ),
    )

    stats.valid_examples += 1
    return datum
| |
|
| |
|
def _convert_examples(raw_data: List[Dict[str, Any]], use_tinker: bool, stats: PreprocessingStats) -> List[Any]:
    """Run the Tinker or mock per-example conversion over ``raw_data``, keeping accepted results."""
    if use_tinker:
        # Imported lazily so mock mode works without Tinker installed.
        from tinker_cookbook import renderers, tokenizer_utils

        print(f"Initializing tokenizer for {MODEL_NAME}...")
        tokenizer = tokenizer_utils.get_tokenizer(MODEL_NAME)
        renderer = renderers.get_renderer(name=RENDERER_NAME, tokenizer=tokenizer)
        vocab_size = len(tokenizer)
        print(f"Vocab size: {vocab_size}")

        def convert(example):
            return preprocess_with_tinker(example, renderer, tokenizer, vocab_size, stats)
    else:
        print("Running mock preprocessing (no Tinker)...")

        def convert(example):
            return preprocess_example_mock(example, stats)

    processed = []
    for i, example in enumerate(raw_data):
        if i % 100 == 0:
            print(f"Processing {i}/{len(raw_data)}...")
        result = convert(example)
        if result is not None:
            processed.append(result)
    return processed


def _write_records(records: List[Any], path: str) -> None:
    """Write records as one JSON array; non-dict records are serialized via their ``model_dump()``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump([r if isinstance(r, dict) else r.model_dump() for r in records], f)


def _print_summary(stats: PreprocessingStats, train_count: int, test_count: int, train_path: str, test_path: str) -> None:
    """Print the end-of-run counters and output locations."""
    print("\n=== Preprocessing Complete ===")
    print(f"Total examples: {stats.total_examples}")
    print(f"Valid examples: {stats.valid_examples}")
    print(f"Skipped (too long): {stats.skipped_too_long}")
    print(f"Skipped (zero weights): {stats.skipped_zero_weights}")
    print(f"Skipped (invalid tokens): {stats.skipped_invalid_tokens}")
    print(f"Skipped (invalid categories): {stats.skipped_invalid_categories}")
    print(f"\nTrain set: {train_count} examples")
    print(f"Test set: {test_count} examples")
    print("\nSaved to:")
    print(f" Train: {train_path}")
    print(f" Test: {test_path}")


def preprocess_dataset(
    input_path: str,
    output_dir: str,
    use_tinker: bool = False,
    train_split: float = 0.8,
    shuffle_seed: int | None = None,
) -> Tuple[PreprocessingStats, str, str]:
    """
    Preprocess the full dataset and write train/test JSON files.

    Args:
        input_path: Path to training_dataset_1000.jsonl
        output_dir: Directory to save processed data (created if missing)
        use_tinker: Whether to use actual Tinker (requires installation)
        train_split: Fraction for training (rest is test); must be in [0, 1]
        shuffle_seed: If given, accepted examples are shuffled with this seed
            before splitting, so the test set is not simply the tail of the
            input file. ``None`` (the default) keeps file order.

    Returns:
        stats, train_path, test_path

    Raises:
        ValueError: If ``train_split`` is outside [0, 1] (or NaN).
    """
    # Out-of-range fractions would silently produce meaningless slices.
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split!r}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading data from {input_path}...")
    raw_data = load_synthetic_data(input_path)
    print(f"Loaded {len(raw_data)} examples")

    stats = PreprocessingStats(total_examples=len(raw_data))
    processed_data = _convert_examples(raw_data, use_tinker, stats)

    if shuffle_seed is not None:
        import random

        random.Random(shuffle_seed).shuffle(processed_data)

    split_idx = int(len(processed_data) * train_split)
    train_data = processed_data[:split_idx]
    test_data = processed_data[split_idx:]

    train_path = os.path.join(output_dir, "train_data.json")
    test_path = os.path.join(output_dir, "test_data.json")
    _write_records(train_data, train_path)
    _write_records(test_data, test_path)

    _print_summary(stats, len(train_data), len(test_data), train_path, test_path)
    return stats, train_path, test_path
| |
|
| |
|
def parse_cli_args(argv: List[str]) -> Tuple[str, str, bool]:
    """
    Parse ``[input_path] [output_dir] [--tinker]`` from command-line arguments.

    ``--``-prefixed flags are removed before positional arguments are read,
    so ``--tinker`` may appear in any position without being mistaken for
    the input path or output directory.

    Args:
        argv: Arguments excluding the program name (i.e. ``sys.argv[1:]``).

    Returns:
        (input_path, output_dir, use_tinker)
    """
    positional = [arg for arg in argv if not arg.startswith("--")]
    input_path = positional[0] if positional else "synthetic_data/training_dataset_1000.jsonl"
    output_dir = positional[1] if len(positional) > 1 else "training/processed_data"
    use_tinker = "--tinker" in argv
    return input_path, output_dir, use_tinker


if __name__ == "__main__":
    import sys

    input_path, output_dir, use_tinker = parse_cli_args(sys.argv[1:])
    preprocess_dataset(input_path, output_dir, use_tinker=use_tinker)
| |
|
| |
|