| """ |
| Quick evaluation of the RL checkpoint against the SFT checkpoint on category-routing examples. |
| """ |
|
|
| import asyncio |
| import json |
| import os |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| import tinker |
| from tinker import types |
| from tinker_cookbook import renderers |
| from tinker_cookbook.tokenizer_utils import get_tokenizer |
|
|
# Name passed to get_tokenizer() when building the renderer in test_model().
BASE_MODEL = "meta-llama/Llama-3.1-8B"


# Sampler weights evaluated under the label "RL Model (5 iters)" in main().
RL_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/rl_iter_005"


# Sampler weights evaluated under the label "SFT Model" in main(); the RL
# any-match rate is reported relative to this checkpoint's rate.
SFT_CHECKPOINT = "tinker://398393e1-7182-555d-aa1b-7ddf23892338:train:0/sampler_weights/sft_final_sampler"
|
|
# Labels accepted when parsing a model response in test_model(); any
# comma-separated piece that is not in this set is dropped before scoring.
VALID_CATEGORIES = {
    # Company-scoped memory.
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    # User-scoped memory.
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    # Explicit "no category applies" label.
    "none",
}
|
|
|
|
# System message sent with every evaluation prompt in test_model(). It lists
# the same thirteen labels as VALID_CATEGORIES and asks for a comma-separated
# answer, which is the format the response parser in test_model() splits on.
SYSTEM_PROMPT = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""
|
|
|
|
async def test_model(checkpoint: str, name: str, test_examples: list):
    """Sample one routing prediction per example from a checkpoint and score it.

    Prints a banner, the gold/predicted labels for the first five examples,
    and the aggregate match rates. The body contains no ``await``: each
    sample is resolved with ``.result()`` inside the loop, so examples are
    processed one at a time.

    Args:
        checkpoint: Model path handed to ``create_sampling_client``.
        name: Label printed in the banner.
        test_examples: Dicts with ``"messages"`` (role/content dicts) and
            ``"categories"`` (gold labels); a missing key is treated as an
            empty list.

    Returns:
        ``{"any_match": rate, "exact_match": rate}``. Both rates are 0 when
        ``test_examples`` is empty.
    """
    print(f"\n{'='*60}")
    print(f"TESTING: {name}")
    print(f"Checkpoint: {checkpoint}")
    print(f"{'='*60}")

    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

    sampling_client = service_client.create_sampling_client(model_path=checkpoint)
    stop_sequences = renderer.get_stop_sequences()

    scored = []

    for idx, example in enumerate(test_examples):
        turns = example.get("messages", [])
        gold = example.get("categories", [])

        # Flatten the conversation into "ROLE: content" lines, one per turn.
        transcript = "".join(
            f"{turn['role'].upper()}: {turn['content']}\n" for turn in turns
        )

        prompt = renderer.build_generation_prompt([
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Conversation:\n{transcript}"},
        ])
        params = types.SamplingParams(max_tokens=100, temperature=0.1, stop=stop_sequences)

        future = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1)
        result = future.result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        if success:
            predicted = response["content"]
        else:
            # An unparseable response is scored as an empty prediction.
            predicted = ""

        # Keep only the comma-separated pieces that are known category labels.
        predicted_set = set()
        for piece in predicted.split(","):
            label = piece.strip().lower()
            if label in VALID_CATEGORIES:
                predicted_set.add(label)
        gold_set = {label.lower() for label in gold}

        if gold_set:
            any_match = len(predicted_set & gold_set) > 0
        else:
            # With no gold labels, only an empty prediction counts as a match.
            any_match = len(predicted_set) == 0
        exact_match = predicted_set == gold_set

        scored.append({
            "any_match": any_match,
            "exact_match": exact_match,
            "predicted": predicted,
            "gold": gold,
        })

        # Only the first five examples are echoed.
        if idx >= 5:
            continue
        print(f"\nExample {idx+1}:")
        print(f" Gold: {gold}")
        print(f" Pred: {predicted}")
        print(f" Match: {'Yes' if any_match else 'No'}")

    total = len(scored)
    if total:
        any_match_rate = sum(row["any_match"] for row in scored) / total
        exact_match_rate = sum(row["exact_match"] for row in scored) / total
    else:
        any_match_rate = 0
        exact_match_rate = 0

    print(f"\n--- Results ({total} examples) ---")
    print(f"Any Match: {any_match_rate:.1%}")
    print(f"Exact Match: {exact_match_rate:.1%}")

    return {"any_match": any_match_rate, "exact_match": exact_match_rate}
|
|
|
|
def _category_from_scenario_id(scenario_id: str) -> str:
    """Recover a category label from the start of a scenario id.

    Every dotted label in VALID_CATEGORIES contains an underscore, so cutting
    the id at its first underscore (the previous behaviour) truncates the
    label, e.g. ``company.brand_core_001`` -> ``company.brand``. Prefer the
    longest valid category that prefixes the id; fall back to the old split
    only when no valid category matches.

    NOTE(review): assumes scenario ids begin with the category name -- confirm
    against the dataset.
    """
    candidates = [cat for cat in VALID_CATEGORIES if scenario_id.startswith(cat)]
    if candidates:
        return max(candidates, key=len)
    return scenario_id.split("_")[0]


async def main():
    """Load the dataset, evaluate the RL and SFT checkpoints, and compare them.

    Reads ``synthetic_data/training_dataset_1000.jsonl`` (one JSON object per
    line), converts each row into ``{"messages", "categories"}``, evaluates
    the last 50 rows with ``test_model`` for both checkpoints, and prints the
    any-match rates plus their difference in percentage points.

    Raises:
        OSError: If the dataset file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    print("=" * 60)
    print("LOADING TEST DATA")
    print("=" * 60)

    data = []
    # Decode explicitly as UTF-8 rather than the platform default encoding.
    with open("synthetic_data/training_dataset_1000.jsonl", "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            item = json.loads(line)

            # Non-dict turns are skipped.
            messages = [
                {"role": turn["role"], "content": turn["content"]}
                for turn in item.get("conversation", [])
                if isinstance(turn, dict)
            ]

            # Labels may be a dict holding a "categories" list or a bare list.
            labels = item.get("labels", {})
            if isinstance(labels, dict):
                categories = labels.get("categories", [])
            elif isinstance(labels, list):
                categories = labels
            else:
                categories = []

            if not categories:
                # No explicit labels: derive one from the scenario id.
                # `or ""` also covers an explicit null scenario_id.
                scenario_id = item.get("scenario_id") or ""
                if "." in scenario_id:
                    categories = [_category_from_scenario_id(scenario_id)]

            data.append({
                "messages": messages,
                "categories": categories,
            })

    print(f"Total examples: {len(data)}")

    # NOTE(review): the evaluation slice comes from a file named
    # "training_dataset" -- confirm these last 50 rows were held out from
    # training, otherwise the reported rates are optimistic.
    test_data = data[-50:]
    print(f"Test examples: {len(test_data)}")

    rl_results = await test_model(RL_CHECKPOINT, "RL Model (5 iters)", test_data)

    sft_results = await test_model(SFT_CHECKPOINT, "SFT Model", test_data)

    print("\n" + "=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f"SFT Any Match: {sft_results['any_match']:.1%}")
    print(f"RL Any Match: {rl_results['any_match']:.1%}")
    print(f"Improvement: {(rl_results['any_match'] - sft_results['any_match'])*100:+.1f}pp")
|
|
|
# Script entry point: run the async evaluation when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|
|