| """ |
| Memory Routing RL Environment |
| |
| This implements the MemoryRoutingEnv for Stage 2 (RL Optimization) per PRD Section 8. |
| |
| Per Tinker docs (rl/rl-envs.mdx): |
| - Env operates on tokens, not strings |
| - Implement initial_observation() and step() |
| - EnvGroupBuilder creates groups of environments |
| - RLDataset provides batches of EnvGroupBuilders |
| |
| Per PRD Section 4 (Reward Computation): |
| - R_F1: Token-level F1 between predicted and gold categories |
| - R_temp: Persistence alignment (+1.0 exact, +0.5 adjacent, 0.0 otherwise) |
| - R_parity: Company/user scope alignment |
| - R_eff: Storage efficiency (penalize >3 categories) |
| - R_total = 0.6*R_F1 + 0.2*R_temp + 0.1*R_parity + 0.1*R_eff |
| |
| Per PRD Section 4 (Environment Design): |
| - Single-step bandit: initial_observation returns conversation, step terminates |
| - EnvGroupBuilder clones each conversation across group_size rollouts |
| """ |
|
|
| import json |
| from typing import List, Dict, Any, Tuple, Set, Optional, Sequence |
| from dataclasses import dataclass |
|
|
| |
# The closed set of labels accepted from the model. parse_categories() keeps
# only labels found here and drops "none" when it appears next to other labels.
VALID_CATEGORIES = {
    # Company-scoped labels.
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    # User-scoped labels.
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    # Stand-alone label with no company/user prefix.
    "none",
}
|
|
| |
# Persistence horizon attached to each label. compute_temporal_reward() takes
# a majority vote over these values for the predicted and the gold label sets.
# NOTE(review): "evolving" appears here but has no partner in the adjacency
# table of compute_temporal_reward(), so it only ever scores on an exact match
# - confirm that this is intended.
CATEGORY_PERSISTENCE = {
    # Company-scoped labels.
    "company.brand_core": "long",
    "company.strategic_signatures": "long",
    "company.knowledge_artifacts": "long",
    "company.business_priorities": "short",
    "company.tools_config": "medium",
    "company.performance_context": "rolling",
    # User-scoped labels.
    "user.communication_style": "long",
    "user.strategic_approach": "long",
    "user.role_context": "medium",
    "user.workflow_patterns": "medium",
    "user.session_history": "short",
    "user.interaction_preferences": "evolving",
    # Stand-alone label.
    "none": "short",
}
|
|
| |
# Owner scope of each label ("company", "user" or "none").
# compute_parity_reward() collapses a label set to a single scope via this map.
CATEGORY_SCOPE = {
    # Company-scoped labels.
    "company.brand_core": "company",
    "company.strategic_signatures": "company",
    "company.knowledge_artifacts": "company",
    "company.business_priorities": "company",
    "company.tools_config": "company",
    "company.performance_context": "company",
    # User-scoped labels.
    "user.communication_style": "user",
    "user.strategic_approach": "user",
    "user.role_context": "user",
    "user.workflow_patterns": "user",
    "user.session_history": "user",
    "user.interaction_preferences": "user",
    # Stand-alone label.
    "none": "none",
}
|
|
|
|
@dataclass
class RewardComponents:
    """Itemised result of scoring one prediction.

    All numeric fields default to 0.0 and ``format_valid`` defaults to True,
    so a freshly constructed instance represents "nothing scored yet".
    """

    # F1 overlap between the predicted and gold label sets.
    r_f1: float = 0.0
    # Persistence-alignment score.
    r_temp: float = 0.0
    # Company/user scope-alignment score.
    r_parity: float = 0.0
    # Storage-efficiency score.
    r_eff: float = 0.0
    # Final scalar reward.
    r_total: float = 0.0
    # False when the model output could not be parsed into any valid label.
    format_valid: bool = True
|
|
|
|
def parse_categories(text: str) -> Tuple[Set[str], bool]:
    """
    Extract the recognised category labels from a comma-separated string.

    Each comma-separated piece is stripped and lower-cased; pieces that are
    not members of VALID_CATEGORIES are ignored.

    Returns:
        (labels, parse_success). parse_success is False - and labels is an
        empty set - when the text is empty/blank or contains no recognised
        label. When "none" appears together with other recognised labels it
        is removed from the result.
    """
    if not (text and text.strip()):
        return set(), False

    recognised: Set[str] = set()
    for piece in text.split(","):
        label = piece.strip().lower()
        if label in VALID_CATEGORIES:
            recognised.add(label)

    if not recognised:
        return set(), False

    # "none" is only meaningful on its own; discard() is a no-op if absent.
    if len(recognised) > 1:
        recognised.discard("none")

    return recognised, True
|
|
|
|
def compute_f1(predicted: Set[str], gold: Set[str]) -> float:
    """
    Return the F1 score of the predicted label set against the gold label set.

    Precision is |predicted & gold| / |predicted| and recall is
    |predicted & gold| / |gold|; the result is their harmonic mean.

    Special cases:
        - both sets empty  -> 1.0
        - exactly one set empty -> 0.0
        - no overlap -> 0.0
    """
    if not predicted or not gold:
        # Two empty sets agree perfectly; one empty set cannot match at all.
        return 1.0 if not predicted and not gold else 0.0

    overlap = len(predicted & gold)
    if overlap == 0:
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(gold)
    return 2 * (precision * recall) / (precision + recall)
|
|
|
|
def compute_temporal_reward(predicted: Set[str], gold: Set[str]) -> float:
    """
    Compute the temporal (persistence) alignment reward.

    Each label is mapped to its persistence via CATEGORY_PERSISTENCE (labels
    missing from the map count as "medium"), and each set is reduced to a
    single persistence by majority vote.

    Returns:
        - 0.0 if either set is empty
        - 1.0 if the two majority persistences are equal
        - 0.5 if they form one of the adjacent pairs long/medium,
          medium/short, medium/rolling or short/rolling (in either order)
        - 0.0 otherwise

    Ties in the majority vote are broken alphabetically by persistence name.
    The inputs are sets, whose iteration order for strings varies between
    interpreter runs, so a first-seen tie-break would make the reward
    irreproducible.
    """
    # Function-scope import: the module does not import Counter at top level.
    from collections import Counter

    if not predicted or not gold:
        return 0.0

    def majority(categories: Set[str]) -> str:
        """Most frequent persistence in *categories*; ties -> alphabetical."""
        counts = Counter(CATEGORY_PERSISTENCE.get(c, "medium") for c in categories)
        # Highest count first, then the smallest name, independent of the
        # order in which the set happened to be iterated.
        return min(counts, key=lambda name: (-counts[name], name))

    pred_pers = majority(predicted)
    gold_pers = majority(gold)

    if pred_pers == gold_pers:
        return 1.0

    # Unordered pairs, so one entry covers both directions.
    adjacent_pairs = {
        frozenset(("long", "medium")),
        frozenset(("medium", "short")),
        frozenset(("medium", "rolling")),
        frozenset(("short", "rolling")),
    }
    if frozenset((pred_pers, gold_pers)) in adjacent_pairs:
        return 0.5

    return 0.0
|
|
|
|
def compute_parity_reward(predicted: Set[str], gold: Set[str]) -> float:
    """
    Compute the company/user scope alignment reward.

    Each label set is collapsed to one of "mixed" (both company and user
    labels present), "company", "user" or "none" using CATEGORY_SCOPE; labels
    missing from the map count as "none".

    Returns:
        1.0 if the predicted and gold scopes are identical, otherwise 0.0.
    """
    def get_scope(categories: Set[str]) -> str:
        seen = {CATEGORY_SCOPE.get(label, "none") for label in categories}
        has_company = "company" in seen
        has_user = "user" in seen
        if has_company and has_user:
            return "mixed"
        if has_company:
            return "company"
        if has_user:
            return "user"
        return "none"

    if get_scope(predicted) == get_scope(gold):
        return 1.0
    return 0.0
|
|
|
|
def compute_efficiency_reward(predicted: Set[str]) -> float:
    """
    Compute the storage-efficiency reward from the number of predicted labels.

    Returns:
        - 1.0 for 3 or fewer labels
        - 0.7 for exactly 4 labels
        - 0.4 for exactly 5 labels
        - 0.0 for 6 or more labels
    """
    count = len(predicted)
    if count <= 3:
        return 1.0

    # Anything beyond the two penalised sizes earns nothing.
    penalised = {4: 0.7, 5: 0.4}
    return penalised.get(count, 0.0)
|
|
|
|
def compute_reward(predicted_text: str, gold_categories: List[str]) -> RewardComponents:
    """
    Score one model output against the gold labels.

    The output is parsed with parse_categories(). If parsing fails, the result
    has format_valid=False, r_total=-1.0 and every component left at 0.0.
    Otherwise the four components are computed and combined as

        r_total = 0.6 * r_f1 + 0.2 * r_temp + 0.1 * r_parity + 0.1 * r_eff

    Returns:
        RewardComponents holding the individual components and the total.
    """
    predicted, parsed_ok = parse_categories(predicted_text)
    gold = set(gold_categories)

    if not parsed_ok:
        # Unparseable output: flat penalty, no component scores.
        return RewardComponents(r_total=-1.0, format_valid=False)

    f1 = compute_f1(predicted, gold)
    temporal = compute_temporal_reward(predicted, gold)
    parity = compute_parity_reward(predicted, gold)
    efficiency = compute_efficiency_reward(predicted)

    total = 0.6 * f1 + 0.2 * temporal + 0.1 * parity + 0.1 * efficiency

    return RewardComponents(
        r_f1=f1,
        r_temp=temporal,
        r_parity=parity,
        r_eff=efficiency,
        r_total=total,
    )
|
|
|
|
| |
| |
|
|
class MemoryRoutingEnv:
    """
    Single-step bandit environment for memory routing.

    initial_observation() returns the prompt tokens and the stop condition;
    step() scores the sampled action with compute_reward() and always ends
    the episode.

    Args:
        conversation: Raw conversation turns (stored, not used for scoring).
        gold_categories: Gold labels the action is scored against.
        prompt_tokens: Token ids of the rendered prompt.
        stop_tokens: Stop condition handed to the sampler.
        scenario_id: Identifier copied into the step info dict.
        tokenizer: Optional object with a ``decode(token_ids)`` method used to
            turn a token-id action back into text. Without it, a token-id
            action can only be stringified as "[1, 2, ...]", which never
            parses as categories.
    """

    def __init__(
        self,
        conversation: List[Dict[str, str]],
        gold_categories: List[str],
        prompt_tokens: List[int],
        stop_tokens: List[int],
        scenario_id: str = "",
        tokenizer: Optional[Any] = None,
    ):
        self.conversation = conversation
        self.gold_categories = gold_categories
        self.prompt_tokens = prompt_tokens
        self.stop_tokens = stop_tokens
        self.scenario_id = scenario_id
        self.tokenizer = tokenizer
        self._done = False

    async def initial_observation(self):
        """
        Return ``(observation, stop_condition)`` for the start of the episode.

        The observation is the prompt wrapped in a tinker ModelInput.
        """
        from tinker import types
        from tinker_cookbook.rl.types import StopCondition

        observation = types.ModelInput.from_ints(self.prompt_tokens)
        # NOTE(review): call shape kept from the original; confirm that the
        # installed tinker_cookbook exposes StopCondition as a constructible
        # type accepting ``stop_tokens``.
        stop_condition = StopCondition(stop_tokens=self.stop_tokens)

        return observation, stop_condition

    def _action_to_text(self, action: Any) -> str:
        """Convert a sampled action (text or token ids) into text to parse."""
        if isinstance(action, str):
            return action
        if self.tokenizer is None or not isinstance(action, (list, tuple)):
            # Nothing to decode with: keep the original best-effort fallback.
            return str(action)

        token_ids = list(action)
        try:
            # Assumes a Hugging Face-style tokenizer; special tokens (e.g. a
            # trailing end-of-turn marker) would otherwise stick to the last
            # category name and make it unrecognisable.
            return self.tokenizer.decode(token_ids, skip_special_tokens=True)
        except TypeError:
            # Tokenizer whose decode() takes no such keyword.
            return self.tokenizer.decode(token_ids)

    async def step(self, action):
        """
        Score the model's action and terminate the episode.

        Returns:
            A StepResult carrying the total reward, done=True and an info
            dict with the reward components, format flag and scenario id.
        """
        from tinker_cookbook.rl.types import StepResult

        action_text = self._action_to_text(action)
        reward_components = compute_reward(action_text, self.gold_categories)

        self._done = True

        # NOTE(review): keyword names kept from the original; confirm them
        # against the StepResult definition of the installed tinker_cookbook.
        return StepResult(
            reward=reward_components.r_total,
            done=True,
            info={
                "r_f1": reward_components.r_f1,
                "r_temp": reward_components.r_temp,
                "r_parity": reward_components.r_parity,
                "r_eff": reward_components.r_eff,
                "format_valid": reward_components.format_valid,
                "scenario_id": self.scenario_id,
            },
        )


class MemoryRoutingEnvGroupBuilder:
    """
    Builds ``group_size`` identical environments for one example, so that
    several samples for the same prompt can be compared.

    Args:
        conversation, gold_categories, prompt_tokens, stop_tokens,
        scenario_id, tokenizer: Forwarded unchanged to every MemoryRoutingEnv.
        group_size: Number of environments created by make_envs().
    """

    def __init__(
        self,
        conversation: List[Dict[str, str]],
        gold_categories: List[str],
        prompt_tokens: List[int],
        stop_tokens: List[int],
        group_size: int = 8,
        scenario_id: str = "",
        tokenizer: Optional[Any] = None,
    ):
        self.conversation = conversation
        self.gold_categories = gold_categories
        self.prompt_tokens = prompt_tokens
        self.stop_tokens = stop_tokens
        self.group_size = group_size
        self.scenario_id = scenario_id
        self.tokenizer = tokenizer

    async def make_envs(self) -> Sequence["MemoryRoutingEnv"]:
        """Create ``group_size`` environments sharing the same inputs."""
        return [
            MemoryRoutingEnv(
                conversation=self.conversation,
                gold_categories=self.gold_categories,
                prompt_tokens=self.prompt_tokens,
                stop_tokens=self.stop_tokens,
                scenario_id=self.scenario_id,
                tokenizer=self.tokenizer,
            )
            for _ in range(self.group_size)
        ]

    def logging_tags(self) -> Dict[str, Any]:
        """Return the scenario id, gold-label count and a "none" flag."""
        return {
            "scenario_id": self.scenario_id,
            "num_gold_categories": len(self.gold_categories),
            "has_none": "none" in self.gold_categories,
        }


class MemoryRoutingDataset:
    """
    Dataset that serves batches of MemoryRoutingEnvGroupBuilder objects.

    Args:
        examples: Example dicts. Keys read here: "messages", "conversation",
            "labels" -> "categories", "categories", "scenario_id".
        batch_size: Number of builders per batch.
        group_size: Passed to every builder.
        renderer: Must provide get_stop_sequences() and
            build_generation_prompt(messages) returning an object with
            to_ints().
        tokenizer: Passed to every builder so environments can decode actions.
    """

    def __init__(
        self,
        examples: List[Dict[str, Any]],
        batch_size: int,
        group_size: int,
        renderer,
        tokenizer,
    ):
        self.examples = examples
        self.batch_size = batch_size
        self.group_size = group_size
        self.renderer = renderer
        self.tokenizer = tokenizer
        self.stop_tokens = renderer.get_stop_sequences()

    def __len__(self) -> int:
        """Number of full batches; 0 when batch_size is not positive."""
        if self.batch_size <= 0:
            return 0
        return len(self.examples) // self.batch_size

    def _prompt_messages(self, example: Dict[str, Any]) -> List[Any]:
        """Return the chat messages the model should be prompted with."""
        messages = example.get("messages", [])
        if not messages:
            conversation = example.get("conversation", [])
            categories = example.get("labels", {}).get("categories", [])

            # Imported lazily so the module loads without the training package.
            from training.preprocess import build_routing_prompt
            full_messages = build_routing_prompt(conversation, categories)

            # Drop the last message, as the original code did.
            # Presumably this is the assistant answer - verify against
            # build_routing_prompt.
            return full_messages[:-1]

        # Pre-rendered messages may still end with the assistant's gold
        # answer; prompting with it would leak the label to the policy.
        last = messages[-1]
        if isinstance(last, dict) and last.get("role") == "assistant":
            return messages[:-1]
        return messages

    def _make_builder(self, example: Dict[str, Any]) -> MemoryRoutingEnvGroupBuilder:
        """Render one example into an env-group builder."""
        prompt = self.renderer.build_generation_prompt(self._prompt_messages(example))
        prompt_tokens = prompt.to_ints()

        # Prefer top-level "categories", fall back to labels.categories.
        gold_categories = example.get("categories", [])
        if not gold_categories:
            gold_categories = example.get("labels", {}).get("categories", [])

        return MemoryRoutingEnvGroupBuilder(
            conversation=example.get("conversation", []),
            gold_categories=gold_categories,
            prompt_tokens=prompt_tokens,
            stop_tokens=self.stop_tokens,
            group_size=self.group_size,
            scenario_id=example.get("scenario_id", ""),
            tokenizer=self.tokenizer,
        )

    def get_batch(self, index: int) -> List[MemoryRoutingEnvGroupBuilder]:
        """
        Return ``batch_size`` builders for batch number *index*.

        Examples are taken cyclically starting at
        ``(index * batch_size) % len(examples)``, so batches wrap around the
        end of the data (repeatedly, if batch_size exceeds the data size).
        Returns an empty list when there are no examples or batch_size is not
        positive.
        """
        total = len(self.examples)
        if total == 0 or self.batch_size <= 0:
            return []

        start_idx = (index * self.batch_size) % total
        return [
            self._make_builder(self.examples[(start_idx + offset) % total])
            for offset in range(self.batch_size)
        ]
|
|
|
|
class MemoryRoutingDatasetBuilder:
    """
    Factory that creates the train and test RL datasets.

    Args:
        train_data_path: Path to a JSON file holding the training examples.
        test_data_path: Path to a JSON file holding the test examples.
        batch_size: Batch size of the training dataset; the test dataset uses
            at most this many examples per batch.
        group_size: Number of environments per example.
        model_name: Model name used to look up the tokenizer.
        renderer_name: Renderer name passed to the renderer factory.
    """

    def __init__(
        self,
        train_data_path: str,
        test_data_path: str,
        batch_size: int = 64,
        group_size: int = 8,
        model_name: str = "meta-llama/Llama-3.1-8B",
        renderer_name: str = "llama3",
    ):
        self.train_data_path = train_data_path
        self.test_data_path = test_data_path
        self.batch_size = batch_size
        self.group_size = group_size
        self.model_name = model_name
        self.renderer_name = renderer_name

    @staticmethod
    def _load_examples(path: str) -> Any:
        """Load and return the parsed JSON content of *path*."""
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __call__(self) -> Tuple["MemoryRoutingDataset", "MemoryRoutingDataset"]:
        """
        Create and return ``(train_dataset, test_dataset)``.

        Both datasets share one tokenizer and one renderer.
        """
        from tinker_cookbook import renderers, tokenizer_utils

        tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        renderer = renderers.get_renderer(name=self.renderer_name, tokenizer=tokenizer)

        train_examples = self._load_examples(self.train_data_path)
        test_examples = self._load_examples(self.test_data_path)

        train_dataset = MemoryRoutingDataset(
            examples=train_examples,
            batch_size=self.batch_size,
            group_size=self.group_size,
            renderer=renderer,
            tokenizer=tokenizer,
        )

        # Cap the test batch at the test-set size, but never let it reach 0:
        # the dataset divides by batch_size when computing its length.
        test_batch_size = max(1, min(self.batch_size, len(test_examples)))
        test_dataset = MemoryRoutingDataset(
            examples=test_examples,
            batch_size=test_batch_size,
            group_size=self.group_size,
            renderer=renderer,
            tokenizer=tokenizer,
        )

        return train_dataset, test_dataset
|
|
|
|
| |
if __name__ == "__main__":
    # Manual smoke run: prints the reward breakdown for a few hand-written
    # cases. Each entry is (model output, gold labels, expected format flag).
    test_cases = [
        ("company.brand_core, user.strategic_approach", ["company.brand_core", "user.strategic_approach"], True),
        ("none", ["none"], True),
        ("company.brand_core, none", ["company.brand_core"], True),
        ("invalid_category", ["company.brand_core"], False),
        ("", ["company.brand_core"], False),
        ("company.brand_core", ["company.brand_core", "user.role_context"], True),
    ]

    print("Testing reward computation:")
    print("=" * 60)

    for pred, gold, expected_valid in test_cases:
        result = compute_reward(pred, gold)
        report = (
            f"\nPredicted: '{pred}'",
            f"Gold: {gold}",
            f"Format valid: {result.format_valid} (expected: {expected_valid})",
            f"R_F1: {result.r_f1:.3f}",
            f"R_temp: {result.r_temp:.3f}",
            f"R_parity: {result.r_parity:.3f}",
            f"R_eff: {result.r_eff:.3f}",
            f"R_total: {result.r_total:.3f}",
        )
        # One line per entry, exactly as separate print() calls would emit.
        print(*report, sep="\n")
|
|
|
|