| """ |
| Training Monitor - Check progress and evaluate completed models. |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| from datetime import datetime |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| import tinker |
| from tinker import types |
| from tinker_cookbook import renderers |
| from tinker_cookbook.tokenizer_utils import get_tokenizer |
| import numpy as np |
|
|
# Base model identifier; used below to fetch the matching tokenizer
# for rendering evaluation prompts.
BASE_MODEL = "meta-llama/Llama-3.1-8B"


# Allowed category labels. During evaluation, predicted categories not in
# this set are filtered out before scoring (see quick_eval). Presumably the
# same label set the model was trained on — TODO confirm against training data.
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none"
}
|
|
|
|
def list_training_runs():
    """Print a summary of recent training runs and their checkpoints.

    Shows, for each of the first ten runs returned by the service, the
    counts of SFT and RL checkpoints, the newest RL checkpoint, and a
    completion banner when that checkpoint's id contains 'final'.
    """
    client = tinker.ServiceClient()
    rest = client.create_rest_client()

    run_listing = rest.list_training_runs().result()

    banner = "=" * 70
    print(banner)
    print("TRAINING RUNS")
    print(banner)

    # Only the first ten runs in the listing are summarized.
    for run in run_listing.training_runs[:10]:
        checkpoint_listing = rest.list_checkpoints(run.training_run_id).result()
        all_checkpoints = checkpoint_listing.checkpoints

        # Classify checkpoints by id substring ('sft' also matches
        # sampler variants such as 'sft_final_sampler').
        sft_checkpoints = [c for c in all_checkpoints if 'sft' in c.checkpoint_id]
        rl_checkpoints = [c for c in all_checkpoints if 'rl_' in c.checkpoint_id]

        print(f"\nRun: {run.training_run_id}")
        print(f" Last request: {run.last_request_time}")
        print(f" SFT checkpoints: {len(sft_checkpoints)}")
        print(f" RL checkpoints: {len(rl_checkpoints)}")

        if not rl_checkpoints:
            continue

        newest = max(rl_checkpoints, key=lambda c: c.time)
        print(f" Latest RL: {newest.checkpoint_id}")

        if 'final' in newest.checkpoint_id:
            print(f" STATUS: RL COMPLETE")
            print(f" Final checkpoint: tinker://{run.training_run_id}/{newest.checkpoint_id}")
|
|
|
|
async def quick_eval(
    checkpoint_path: str,
    n_samples: int = 20,
    test_data_path: str = "training/processed_data/test_data.json",
) -> float:
    """Run a quick "any match" accuracy evaluation of a checkpoint.

    Samples one completion per test example and counts an example as
    correct when at least one predicted category is in the gold set.

    Args:
        checkpoint_path: tinker:// path of the sampler checkpoint to evaluate.
        n_samples: Maximum number of test examples to evaluate.
        test_data_path: Path to the JSON test set — a list of dicts with
            "messages" (chat turns) and "categories" (gold labels). Defaults
            to the original hard-coded location, so existing callers are
            unaffected.

    Returns:
        Any-match accuracy in [0, 1]; 0 if no example could be evaluated.
    """
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    # Explicit encoding: avoid platform-dependent default decoding of the JSON.
    with open(test_data_path, "r", encoding="utf-8") as f:
        test_data = json.load(f)

    print(f"\nEvaluating: {checkpoint_path}")
    print(f"Samples: {n_samples}")

    sampling_client = service_client.create_sampling_client(model_path=checkpoint_path)
    stop_sequences = renderer.get_stop_sequences()

    correct = 0
    total = 0

    for example in test_data[:n_samples]:
        gold = example.get("categories", [])
        messages = example.get("messages", [])
        # The prompt is everything except the assistant turn(s).
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue  # malformed example: nothing to prompt with

        prompt = renderer.build_generation_prompt(prompt_messages)
        # Low temperature for near-deterministic sampling; 50 tokens is
        # enough for a short comma-separated category list.
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""

        # Normalize each comma-separated token once, then keep only known labels.
        normalized = (c.strip().lower() for c in predicted_text.split(","))
        predicted_set = {c for c in normalized if c in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

        # "Any match": one overlapping category counts the example as correct.
        if predicted_set & gold_set:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"Any Match Accuracy: {accuracy:.1%} ({correct}/{total})")

    return accuracy
|
|
|
|
def find_best_checkpoint():
    """Scan every training run for the newest final checkpoints.

    Returns:
        A ``(best_sft, best_rl)`` pair, where each element is either
        ``(tinker_path, checkpoint_time)`` or ``None`` when no matching
        checkpoint exists. SFT candidates match 'sft_final_sampler' in
        the checkpoint id; RL candidates match 'rl_final'.
    """
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()

    runs = rest_client.list_training_runs().result()

    rl_candidates = []
    sft_candidates = []

    # Gather every matching (path, time) pair across all runs, then
    # pick the most recent of each kind afterwards.
    for run in runs.training_runs:
        ckpts = rest_client.list_checkpoints(run.training_run_id).result()

        for ckpt in ckpts.checkpoints:
            path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"

            if 'rl_final' in ckpt.checkpoint_id:
                rl_candidates.append((path, ckpt.time))

            if 'sft_final_sampler' in ckpt.checkpoint_id:
                sft_candidates.append((path, ckpt.time))

    # max keeps the first maximal element on ties, matching the original
    # strict-greater-than comparison.
    best_rl = max(rl_candidates, key=lambda pair: pair[1]) if rl_candidates else None
    best_sft = max(sft_candidates, key=lambda pair: pair[1]) if sft_candidates else None

    return best_sft, best_rl
|
|
|
|
async def main():
    """CLI entry point.

    With an ``eval`` argument, evaluates the best SFT and RL checkpoints;
    otherwise lists training runs and prints the best checkpoints found.
    """
    import sys

    divider = "=" * 70
    wants_eval = len(sys.argv) > 1 and sys.argv[1] == "eval"

    if wants_eval:
        sft_ckpt, rl_ckpt = find_best_checkpoint()

        print(divider)
        print("CHECKPOINT EVALUATION")
        print(divider)

        if sft_ckpt:
            print(f"\nBest SFT: {sft_ckpt[0]}")
            await quick_eval(sft_ckpt[0], n_samples=50)

        if rl_ckpt:
            print(f"\nBest RL: {rl_ckpt[0]}")
            await quick_eval(rl_ckpt[0], n_samples=50)

        return

    # Default path: status listing plus a best-checkpoint summary.
    list_training_runs()

    sft_ckpt, rl_ckpt = find_best_checkpoint()

    print("\n" + divider)
    print("BEST CHECKPOINTS")
    print(divider)

    if sft_ckpt:
        print(f"\nSFT: {sft_ckpt[0]}")
        print(f" Time: {sft_ckpt[1]}")

    if rl_ckpt:
        print(f"\nRL: {rl_ckpt[0]}")
        print(f" Time: {rl_ckpt[1]}")

    print("\nTo evaluate, run: python training/monitor.py eval")
|
|
|
|
| if __name__ == "__main__": |
| asyncio.run(main()) |
|
|
|
|