| """ |
| Full Training Pipeline: SFT -> RL -> Evaluation |
| |
| This script implements the complete training pipeline for the Memory Routing Agent |
| following best practices from Tinker documentation and ML research. |
| |
| Key insights from the codebase analysis: |
| 1. SFT must save with save_state() for RL to continue from those weights |
| 2. RL uses importance_sampling loss with proper advantage normalization |
| 3. Evaluation should compare against baseline (untrained) and larger models |
| |
| Architecture decisions: |
| - Base model: Llama-3.1-8B (good balance of capability and efficiency) |
| - LoRA rank 32 (sufficient for classification, per Tinker docs) |
| - SFT: 100 steps with early stopping, then RL: 15 iterations |
| """ |
|
|
| import asyncio |
| import json |
| import time |
| import os |
| import numpy as np |
| from typing import List, Dict, Any, Tuple, Optional |
| from dataclasses import dataclass, field |
| from collections import Counter |
| from datetime import datetime |
|
|
|
|
| @dataclass |
| class PipelineConfig: |
| """Configuration for the full training pipeline.""" |
| |
| |
    # Model and adapter
    base_model: str = "meta-llama/Llama-3.1-8B"
    lora_rank: int = 32
    renderer_name: str = "llama3"
| |
| |
    # SFT phase
    sft_steps: int = 100
    sft_batch_size: int = 64
    sft_lr: Optional[float] = None  # None -> derived from model size via get_lr()
    sft_eval_every: int = 10
    sft_early_stopping_patience: int = 5  # counted in evaluations, not steps
| |
| |
    # RL phase
    rl_iterations: int = 15
    rl_batch_size: int = 32  # prompts per iteration
    rl_group_size: int = 8   # samples per prompt (group baseline)
    rl_lr: float = 1e-5
    rl_temperature: float = 0.7
    rl_kl_threshold: float = 0.01  # reserved for KL-based stopping; not yet wired in
| |
| |
    # Data
    train_data_path: str = "training/processed_data/train_data.json"
    test_data_path: str = "training/processed_data/test_data.json"
| |
| |
    # Experiment bookkeeping
    experiment_name: str = "memory_routing_v1"
    output_dir: str = "training/experiments"
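
# A quick smoke test can shrink the pipeline (hypothetical override values):
#
#     config = PipelineConfig(sft_steps=20, sft_eval_every=5,
#                             rl_iterations=3, rl_batch_size=8)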
|
|
|
|
| |
# The 13 categories the router may emit: 12 memory slots plus "none".
VALID_CATEGORIES = {
| "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts", |
| "company.business_priorities", "company.tools_config", "company.performance_context", |
| "user.communication_style", "user.strategic_approach", "user.role_context", |
| "user.workflow_patterns", "user.session_history", "user.interaction_preferences", |
| "none" |
| } |
|
|
# Expected persistence tier per category; feeds the temporal reward term.
CATEGORY_PERSISTENCE = {
| "company.brand_core": "long", "company.strategic_signatures": "long", |
| "company.knowledge_artifacts": "long", "company.business_priorities": "short", |
| "company.tools_config": "medium", "company.performance_context": "rolling", |
| "user.communication_style": "long", "user.strategic_approach": "long", |
| "user.role_context": "medium", "user.workflow_patterns": "medium", |
| "user.session_history": "short", "user.interaction_preferences": "evolving", |
| "none": "short" |
| } |
|
|
|
|
| def compute_reward(predicted_text: str, gold_categories: List[str]) -> Tuple[float, Dict]: |
| """ |
| Compute reward with detailed breakdown. |
| |
| R_total = 0.6 * R_F1 + 0.2 * R_temp + 0.1 * R_parity + 0.1 * R_eff |
| """ |
    info = {"format_valid": True, "r_f1": 0.0, "r_temp": 0.0, "r_parity": 0.0, "r_eff": 0.0}

    # Empty output is a hard format failure.
    if not predicted_text or not predicted_text.strip():
        info["format_valid"] = False
        return -1.0, info

    # Keep only recognizable category names.
    predicted = {c.strip().lower() for c in predicted_text.split(",")
                 if c.strip().lower() in VALID_CATEGORIES}

    if not predicted:
        info["format_valid"] = False
        return -1.0, info
| |
| |
| if "none" in predicted and len(predicted) > 1: |
| predicted.discard("none") |
| |
| gold = set([c.lower() for c in gold_categories]) |
| |
| |
| if predicted and gold: |
| tp = len(predicted & gold) |
| precision = tp / len(predicted) |
| recall = tp / len(gold) |
| info["r_f1"] = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 |
| elif not predicted and not gold: |
| info["r_f1"] = 1.0 |
| |
| |
    # R_temp: do predicted and gold imply the same persistence tier?
    def majority_persistence(cats):
        if not cats:
            return "medium"
        persis = [CATEGORY_PERSISTENCE.get(c, "medium") for c in cats]
        return Counter(persis).most_common(1)[0][0]

    pred_pers = majority_persistence(predicted)
    gold_pers = majority_persistence(gold)

    if pred_pers == gold_pers:
        info["r_temp"] = 1.0
    elif (pred_pers, gold_pers) in [("long", "medium"), ("medium", "long"),
                                    ("medium", "short"), ("short", "medium")]:
        info["r_temp"] = 0.5  # adjacent tiers earn partial credit
| |
| |
    # R_parity: does the prediction target the same scope (company/user/mixed)?
    def get_scope(cats):
        scopes = set()
        for c in cats:
            if c.startswith("company."):
                scopes.add("company")
            elif c.startswith("user."):
                scopes.add("user")
        if len(scopes) == 2:
            return "mixed"
        return scopes.pop() if scopes else "none"

    if get_scope(predicted) == get_scope(gold):
        info["r_parity"] = 1.0
| |
| |
    # R_eff: penalize over-prediction; six or more categories earn nothing.
    n = len(predicted)
    info["r_eff"] = 1.0 if n <= 3 else (0.7 if n == 4 else (0.4 if n == 5 else 0.0))

    # Weighted total.
    r_total = 0.6 * info["r_f1"] + 0.2 * info["r_temp"] + 0.1 * info["r_parity"] + 0.1 * info["r_eff"]

    return r_total, info
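

# Illustrative sanity check (hypothetical call; the numbers follow from the
# weights above):
#
#     reward, info = compute_reward("user.communication_style",
#                                   ["user.communication_style"])
#     # Exact match, matching persistence tier ("long"), matching scope
#     # ("user"), single category -> reward == 0.6 + 0.2 + 0.1 + 0.1 == 1.0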
|
|
|
|
| async def run_sft_phase(config: PipelineConfig, service_client, tokenizer, renderer): |
| """ |
| Phase 1: Supervised Fine-Tuning |
| |
| Key principles: |
| - Use cross_entropy loss for next-token prediction |
| - Monitor train/test loss for overfitting |
| - Save full state checkpoint for RL continuation |
| """ |
| import tinker |
| from tinker import types |
| from tinker_cookbook.hyperparam_utils import get_lr |
| |
| print("\n" + "=" * 70) |
| print("PHASE 1: SUPERVISED FINE-TUNING") |
| print("=" * 70) |
| |
| |
    # Load the pre-processed chat-format datasets.
    with open(config.train_data_path, "r") as f:
| train_data_raw = json.load(f) |
| with open(config.test_data_path, "r") as f: |
| test_data_raw = json.load(f) |
| |
| print(f"Train: {len(train_data_raw)}, Test: {len(test_data_raw)}") |
| |
| |
    # Default to the cookbook's model-size-aware learning rate.
    lr = config.sft_lr or get_lr(config.base_model)
| print(f"Learning rate: {lr:.2e}") |
| |
| |
    # Create the LoRA training client.
    training_client = await service_client.create_lora_training_client_async(
| base_model=config.base_model, |
| rank=config.lora_rank, |
| ) |
| |
| |
    # Convert each chat example into a next-token-prediction Datum.
    def to_datum(item):
        messages = item.get("messages", [])
        tokens, weights = renderer.build_supervised_example(messages)
        if hasattr(tokens, 'tolist'):
            tokens = tokens.tolist()
        if hasattr(weights, 'tolist'):
            weights = weights.tolist()

        # Shift by one: the model reads tokens[:-1] and is scored on tokens[1:].
        return types.Datum(
            model_input=types.ModelInput.from_ints(tokens[:-1]),
            loss_fn_inputs=dict(target_tokens=tokens[1:], weights=weights[1:])
        )
| |
| train_data = [to_datum(item) for item in train_data_raw] |
| test_data = [to_datum(item) for item in test_data_raw[:50]] |
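
    # Alignment sketch for a 4-token example (hypothetical tokens t0..t3):
    #   tokens        = [t0, t1, t2, t3]
    #   model_input   = [t0, t1, t2]   # tokens[:-1]
    #   target_tokens = [t1, t2, t3]   # tokens[1:]
    #   weights       = [w1, w2, w3]   # weights[1:]; the renderer typically
    #                                  # zeroes out non-assistant tokens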
| |
| |
    # Training loop with periodic held-out evaluation and early stopping.
    metrics_log = []
    best_test_loss = float('inf')
    no_improvement = 0
| |
| for step in range(config.sft_steps): |
| step_start = time.time() |
| |
| |
        # Cycle through the dataset, wrapping so every batch is full-size.
        batch_idx = (step * config.sft_batch_size) % len(train_data)
        batch = train_data[batch_idx:batch_idx + config.sft_batch_size]
        if len(batch) < config.sft_batch_size:
            batch = batch + train_data[:config.sft_batch_size - len(batch)]
| |
| |
        # Submit forward/backward and the optimizer step back-to-back, then
        # await both, letting the backend pipeline the two requests.
        fwd_future = await training_client.forward_backward_async(batch, loss_fn="cross_entropy")
        optim_future = await training_client.optim_step_async(
            types.AdamParams(learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
        )

        fwd_result = await fwd_future.result_async()
        await optim_future.result_async()
| |
| |
        # Weighted mean NLL over supervised tokens:
        #   train_loss = -(sum_t w_t * logprob_t) / sum_t w_t
        logprobs = np.concatenate([o['logprobs'].tolist() for o in fwd_result.loss_fn_outputs])
        weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in batch])
        train_loss = -np.dot(logprobs, weights) / max(weights.sum(), 1)
| |
| step_time = time.time() - step_start |
| |
| |
        # Periodic held-out evaluation.
        test_loss = None
        if step % config.sft_eval_every == 0 or step == config.sft_steps - 1:
            # Caveat: forward_backward also accumulates gradients for the eval
            # batch, which fold into the next optim_step; a forward-only pass
            # would be cleaner here if the API offers one.
            eval_future = await training_client.forward_backward_async(test_data, loss_fn="cross_entropy")
            eval_result = await eval_future.result_async()
            test_logprobs = np.concatenate([o['logprobs'].tolist() for o in eval_result.loss_fn_outputs])
            test_weights = np.concatenate([d.loss_fn_inputs['weights'].tolist() for d in test_data])
            test_loss = -np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1)
| |
| |
            # Early-stopping bookkeeping (patience counts evaluations).
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                no_improvement = 0
            else:
                no_improvement += 1
| |
| print(f"Step {step:3d}: train_loss={train_loss:.4f}, test_loss={test_loss:.4f}, time={step_time:.1f}s") |
| |
| if no_improvement >= config.sft_early_stopping_patience: |
| print(f"Early stopping at step {step}") |
| break |
| else: |
| print(f"Step {step:3d}: train_loss={train_loss:.4f}, time={step_time:.1f}s") |
| |
| metrics_log.append({ |
| "step": step, "train_loss": float(train_loss), |
| "test_loss": float(test_loss) if test_loss else None, |
| "time": step_time |
| }) |
| |
| |
| print("\nSaving final SFT checkpoint...") |
| state_future = await training_client.save_state_async(name="sft_final") |
| state_result = await state_future.result_async() |
| sft_checkpoint = state_result.path |
| |
| |
    # A separate sampler-format checkpoint is needed for evaluation sampling.
    sampler_future = await training_client.save_weights_for_sampler_async(name="sft_final_sampler")
| sampler_result = await sampler_future.result_async() |
| sampler_checkpoint = sampler_result.path |
| |
| print(f"SFT State checkpoint: {sft_checkpoint}") |
| print(f"SFT Sampler checkpoint: {sampler_checkpoint}") |
| |
| return training_client, sft_checkpoint, sampler_checkpoint, metrics_log |
|
|
|
|
| async def run_rl_phase(config: PipelineConfig, service_client, training_client, |
| sft_checkpoint: str, tokenizer, renderer): |
| """ |
| Phase 2: Reinforcement Learning |
| |
| Key principles: |
| - Load SFT weights to continue training |
| - Use importance_sampling loss for policy gradient |
| - Group rollouts for variance reduction |
    - Monitor KL divergence for stability (rl_kl_threshold is reserved for
      this but not yet wired in below)
| """ |
| import tinker |
| from tinker import types |
| |
| print("\n" + "=" * 70) |
| print("PHASE 2: REINFORCEMENT LEARNING") |
| print("=" * 70) |
| |
| |
    # RL reuses the training prompts; rewards come from compute_reward.
    with open(config.train_data_path, "r") as f:
| train_data = json.load(f) |
| |
| print(f"Training examples: {len(train_data)}") |
| print(f"RL iterations: {config.rl_iterations}") |
| print(f"Batch size: {config.rl_batch_size}, Group size: {config.rl_group_size}") |
| |
| |
| print(f"\nLoading SFT checkpoint: {sft_checkpoint}") |
| await training_client.load_state_async(sft_checkpoint) |
| |
| stop_sequences = renderer.get_stop_sequences() |
| metrics_log = [] |
| |
| for iteration in range(config.rl_iterations): |
| iter_start = time.time() |
| print(f"\n--- Iteration {iteration + 1}/{config.rl_iterations} ---") |
| |
| |
        # Export current weights so sampling reflects the latest policy.
        save_future = await training_client.save_weights_for_sampler_async(
            name=f"rl_iter_{iteration:03d}"
        )
        save_result = await save_future.result_async()
        sampling_client = service_client.create_sampling_client(model_path=save_result.path)
| |
| |
        # Draw a batch of prompts without replacement.
        batch_indices = np.random.choice(len(train_data), size=config.rl_batch_size, replace=False)
| |
| all_rollouts = [] |
| all_rewards = [] |
| reward_infos = [] |
| |
| for idx in batch_indices: |
| example = train_data[idx] |
| gold_categories = example.get("categories", []) |
| messages = example.get("messages", []) |
| prompt_messages = [m for m in messages if m.get("role") != "assistant"] |
| |
| if not prompt_messages: |
| continue |
| |
            prompt = renderer.build_generation_prompt(prompt_messages)
            params = types.SamplingParams(
                max_tokens=50, temperature=config.rl_temperature, stop=stop_sequences
            )

            # Blocking .result() is acceptable here since rollouts are
            # generated sequentially anyway.
            result = sampling_client.sample(
                prompt=prompt, sampling_params=params, num_samples=config.rl_group_size
            ).result()
| |
            # Score every sampled completion against the gold labels.
            for seq in result.sequences:
| response, success = renderer.parse_response(seq.tokens) |
| predicted = response["content"] if success else "" |
| reward, info = compute_reward(predicted, gold_categories) |
| |
| all_rollouts.append({ |
| "prompt": prompt, |
| "tokens": seq.tokens, |
| "logprobs": seq.logprobs or [], |
| "predicted": predicted, |
| "gold": gold_categories |
| }) |
| all_rewards.append(reward) |
| reward_infos.append(info) |
| |
| |
        # Group-wise (GRPO-style) advantage normalization: each prompt's
        # rl_group_size rollouts are contiguous, so every completion is scored
        # against siblings from the same prompt rather than against the whole
        # batch. Batch-level stats are kept for logging only.
        if not all_rewards:
            print("  No rollouts this iteration; skipping update.")
            continue
        rewards_arr = np.array(all_rewards)
        mean_reward = rewards_arr.mean()
        std_reward = rewards_arr.std() + 1e-8
        groups = rewards_arr.reshape(-1, config.rl_group_size)
        advantages = ((groups - groups.mean(axis=1, keepdims=True))
                      / (groups.std(axis=1, keepdims=True) + 1e-8)).reshape(-1)
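
        # Example (hypothetical rewards, group of 4): [1.0, 0.5, 0.5, 0.0] has
        # mean 0.5, so its advantages come out positive, ~zero, ~zero, negative:
        # a completion is only reinforced for beating its own group.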
| |
| |
        # Assemble importance-sampling data over the full sequence
        # (prompt + generation). Prompt positions get logprob 0.0 and
        # advantage 0.0, so they contribute nothing to the loss.
        training_data = []
        for i, rollout in enumerate(all_rollouts):
            if not rollout["logprobs"]:
                continue

            prompt_tokens = rollout["prompt"].to_ints()
            gen_tokens = rollout["tokens"]
            logprobs = rollout["logprobs"]
            adv = advantages[i]

            n_prompt = len(prompt_tokens) - 1  # prompt positions after the shift
            n_gen = len(gen_tokens)

            if len(logprobs) != n_gen:
                continue

            # Shifted alignment: position t of the input predicts target t.
            # (gen_tokens is non-empty here, so the concatenation is safe.)
            full_input = prompt_tokens + gen_tokens[:-1]
            full_target = prompt_tokens[1:] + gen_tokens
            full_logprobs = [0.0] * n_prompt + logprobs
            full_advantages = [0.0] * n_prompt + [adv] * n_gen

            # Defensive length check before building the Datum.
            if len(full_target) != len(full_input) or len(full_logprobs) != len(full_input):
                continue
| |
| training_data.append(types.Datum( |
| model_input=types.ModelInput.from_ints(full_input), |
| loss_fn_inputs=dict( |
| target_tokens=full_target, |
| logprobs=full_logprobs, |
| advantages=full_advantages |
| ) |
| )) |
| |
| |
        # One policy-gradient step on the collected rollouts.
        if training_data:
| fwd_future = await training_client.forward_backward_async( |
| training_data, loss_fn="importance_sampling" |
| ) |
| optim_future = await training_client.optim_step_async( |
| types.AdamParams(learning_rate=config.rl_lr, beta1=0.9, beta2=0.95, eps=1e-8) |
| ) |
| await fwd_future.result_async() |
| await optim_future.result_async() |
| |
| |
        # Iteration metrics. "accuracy" here is the fraction of rollouts with
        # positive reward; format_valid_rate tracks parseable outputs.
        iter_time = time.time() - iter_start
        accuracy = sum(1 for r in all_rewards if r > 0) / len(all_rewards) if all_rewards else 0
        format_valid_rate = sum(1 for info in reward_infos if info["format_valid"]) / len(reward_infos) if reward_infos else 0
| |
| metrics = { |
| "iteration": iteration, |
| "mean_reward": float(mean_reward), |
| "std_reward": float(std_reward), |
| "accuracy": accuracy, |
| "format_valid_rate": format_valid_rate, |
| "num_rollouts": len(all_rollouts), |
| "time": iter_time |
| } |
| metrics_log.append(metrics) |
| |
| print(f" Reward: {mean_reward:.3f} ± {std_reward:.3f}, Acc: {accuracy:.1%}, Format: {format_valid_rate:.1%}") |
| |
| |
| print("\nSaving final RL checkpoint...") |
| final_future = await training_client.save_weights_for_sampler_async(name="rl_final") |
| final_result = await final_future.result_async() |
| rl_checkpoint = final_result.path |
| |
| print(f"RL checkpoint: {rl_checkpoint}") |
| |
| return rl_checkpoint, metrics_log |
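

# Illustrative ad-hoc inference with the returned checkpoint (hypothetical
# prompt text; the same client/renderer calls as above):
#
#     sc = service_client.create_sampling_client(model_path=rl_checkpoint)
#     prompt = renderer.build_generation_prompt(
#         [{"role": "user", "content": "Route this memory: ..."}])
#     out = sc.sample(prompt=prompt,
#                     sampling_params=types.SamplingParams(max_tokens=50),
#                     num_samples=1).result()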
|
|
|
|
| async def run_evaluation(config: PipelineConfig, service_client, checkpoint: str, |
| tokenizer, renderer, name: str = "model"): |
| """ |
| Comprehensive evaluation on test set. |
| """ |
| import tinker |
| from tinker import types |
| |
| print(f"\n--- Evaluating: {name} ---") |
| |
| |
    # Evaluate over the full test set with near-greedy decoding.
    with open(config.test_data_path, "r") as f:
| test_data = json.load(f) |
| |
| sampling_client = service_client.create_sampling_client(model_path=checkpoint) |
| stop_sequences = renderer.get_stop_sequences() |
| |
| results = [] |
| |
| for i, example in enumerate(test_data): |
| gold = example.get("categories", []) |
| messages = example.get("messages", []) |
| prompt_messages = [m for m in messages if m.get("role") != "assistant"] |
| |
| if not prompt_messages: |
| continue |
| |
| prompt = renderer.build_generation_prompt(prompt_messages) |
| params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences) |
| |
| result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result() |
| response, success = renderer.parse_response(result.sequences[0].tokens) |
| predicted = response["content"] if success else "" |
| |
| reward, info = compute_reward(predicted, gold) |
| |
        # Re-parse the prediction for per-example precision/recall metrics.
        predicted_set = {c.strip().lower() for c in predicted.split(",") if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}
| |
| results.append({ |
| "gold": gold, |
| "predicted": predicted, |
| "reward": reward, |
| "exact_match": predicted_set == gold_set, |
| "any_match": len(predicted_set & gold_set) > 0, |
| "precision": len(predicted_set & gold_set) / len(predicted_set) if predicted_set else 0, |
| "recall": len(predicted_set & gold_set) / len(gold_set) if gold_set else 0, |
| "format_valid": info["format_valid"] |
| }) |
| |
| if (i + 1) % 50 == 0: |
| print(f" Evaluated {i + 1}/{len(test_data)}") |
| |
| |
    # Aggregate: precision/recall are averaged per example, and F1 is computed
    # from those averages (not a per-example F1 average).
    n = len(results)
    metrics = {
        "name": name,
        "n_examples": n,
        "mean_reward": np.mean([r["reward"] for r in results]),
        "exact_match": np.mean([r["exact_match"] for r in results]),
        "any_match": np.mean([r["any_match"] for r in results]),
        "precision": np.mean([r["precision"] for r in results]),
        "recall": np.mean([r["recall"] for r in results]),
        "format_valid": np.mean([r["format_valid"] for r in results]),
    }
    pr_sum = metrics["precision"] + metrics["recall"]
    metrics["f1"] = 2 * metrics["precision"] * metrics["recall"] / pr_sum if pr_sum > 0 else 0.0
| |
| print(f" Any Match: {metrics['any_match']:.1%}") |
| print(f" Exact Match: {metrics['exact_match']:.1%}") |
| print(f" F1: {metrics['f1']:.1%}") |
| print(f" Mean Reward: {metrics['mean_reward']:.3f}") |
| |
| return metrics, results |
|
|
|
|
| async def main(): |
| """Run the full training pipeline.""" |
| import tinker |
| from tinker_cookbook import renderers |
| from tinker_cookbook.tokenizer_utils import get_tokenizer |
| from dotenv import load_dotenv |
| |
| load_dotenv() |
| |
| config = PipelineConfig() |
| |
| |
    # Each run gets a timestamped experiment directory.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
| exp_dir = os.path.join(config.output_dir, f"{config.experiment_name}_{timestamp}") |
| os.makedirs(exp_dir, exist_ok=True) |
| |
| print("=" * 70) |
| print("MEMORY ROUTING AGENT - FULL TRAINING PIPELINE") |
| print("=" * 70) |
| print(f"Experiment: {config.experiment_name}") |
| print(f"Output: {exp_dir}") |
| print(f"Base model: {config.base_model}") |
| print(f"LoRA rank: {config.lora_rank}") |
| |
| |
    # Shared clients: one service client, plus tokenizer and chat renderer.
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(config.base_model)
    renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)
| |
| |
    # Phase 1: SFT.
    training_client, sft_state_ckpt, sft_sampler_ckpt, sft_metrics = await run_sft_phase(
| config, service_client, tokenizer, renderer |
| ) |
| |
| |
    # Evaluate the SFT checkpoint first, as the comparison baseline for RL.
    sft_eval, _ = await run_evaluation(
| config, service_client, sft_sampler_ckpt, tokenizer, renderer, "SFT Model" |
| ) |
| |
| |
    # Phase 2: RL, resuming from the SFT full-state checkpoint.
    rl_checkpoint, rl_metrics = await run_rl_phase(
| config, service_client, training_client, sft_state_ckpt, tokenizer, renderer |
| ) |
| |
| |
    # Evaluate the RL checkpoint on the same test set.
    rl_eval, _ = await run_evaluation(
| config, service_client, rl_checkpoint, tokenizer, renderer, "RL Model" |
| ) |
| |
| |
    # Persist config, checkpoints, per-phase metrics, and evaluations.
    results = {
| "config": { |
| "base_model": config.base_model, |
| "lora_rank": config.lora_rank, |
| "sft_steps": config.sft_steps, |
| "rl_iterations": config.rl_iterations, |
| }, |
| "checkpoints": { |
| "sft_state": sft_state_ckpt, |
| "sft_sampler": sft_sampler_ckpt, |
| "rl_final": rl_checkpoint, |
| }, |
| "sft_metrics": sft_metrics, |
| "rl_metrics": rl_metrics, |
| "evaluation": { |
| "sft": sft_eval, |
| "rl": rl_eval, |
| } |
| } |
| |
| results_path = os.path.join(exp_dir, "results.json") |
| with open(results_path, "w") as f: |
| json.dump(results, f, indent=2, default=str) |
| |
| print("\n" + "=" * 70) |
| print("TRAINING COMPLETE") |
| print("=" * 70) |
| print(f"Results saved to: {results_path}") |
| print(f"\nFinal Model: {rl_checkpoint}") |
| print(f"\nComparison:") |
| print(f" SFT - F1: {sft_eval['f1']:.1%}, Any Match: {sft_eval['any_match']:.1%}") |
| print(f" RL - F1: {rl_eval['f1']:.1%}, Any Match: {rl_eval['any_match']:.1%}") |
| |
| return results |
|
|
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |
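
# To run end to end (assuming a configured Tinker environment, e.g.
# TINKER_API_KEY available via .env, and the processed data files in place):
#
#     python <path-to-this-script>
#
# Results are written to
# training/experiments/memory_routing_v1_<timestamp>/results.json.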
|
|
|
|