| |
| |
| |
| |
| |
|
|
| """ |
Task registry for meta-learning.

Tasks come either from the internal registry (get_task(task_id) for sinusoid
tasks, get_slm_task(task_id) for SLM tasks) or from outside via
task_spec_from_dict(), where the client sends the task definition to the
environment. Supports sinusoid (regression) and SLM (next-token prediction)
task types.
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, List |
|
|
| import math |
|
|
| from .slm_model import DEFAULT_VOCAB_SIZE as SLM_DEFAULT_VOCAB_SIZE |
|
|
| |
# Registry ids for sinusoid tasks. get_task() maps ids below 50 to
# Distribution A (training) and ids 50 and above to Distribution B (evaluation).
TRAIN_TASK_IDS: List[int] = [*range(50)]
EVAL_TASK_IDS: List[int] = [50, 51]

# Registry ids for SLM tasks, split at the same boundary by get_slm_task().
# Kept as separate list objects from the sinusoid ids above.
SLM_TRAIN_TASK_IDS: List[int] = [*range(50)]
SLM_EVAL_TASK_IDS: List[int] = [50, 51]

# Five pangram sentences, repeated 200 times. Presumably the text behind
# corpus_id="default" — TODO confirm against the corpus consumer.
DEFAULT_CORPUS: str = "".join(
    (
        "The quick brown fox jumps over the lazy dog. ",
        "Pack my box with five dozen liquor jugs. ",
        "How vexingly quick daft zebras jump. ",
        "Sphinx of black quartz, judge my vow. ",
        "The five boxing wizards jump quickly. ",
    )
) * 200

# (low, high) bounds that get_task() interpolates between when deriving a
# task's frequency and amplitude. Distribution A serves train ids and
# Distribution B serves eval ids. The frequency ranges do not overlap, so
# eval tasks are out of the training frequency range.
DIST_A_FREQ = 1.0, 3.0
DIST_A_AMP = 0.5, 2.0
DIST_B_FREQ = 4.0, 6.0
DIST_B_AMP = 0.3, 1.5
|
|
|
|
@dataclass
class TaskSpec:
    """Spec for one sinusoidal regression task.

    Built by get_task() (internal registry) or task_spec_from_dict()
    (external definition). The field order is the positional constructor
    signature generated by @dataclass, so do not reorder fields.
    """

    task_id: int
    # get_task() always uses 1; task_spec_from_dict() reads it from the dict
    # (default 1).
    input_dim: int
    # get_task() derives a value in [32, 64]; task_spec_from_dict() defaults to 32.
    hidden_dim: int
    # Always 1 for specs built in this module.
    output_dim: int
    # Derived from task_id (task_id * 31337 and task_id * 131 + 7) unless an
    # external dict supplies them. How the consumer uses them is not shown here.
    data_seed: int
    arch_seed: int
    # Parameters of the target sinusoid. Presumably
    # y = amplitude * sin(freq * x + phase) — TODO confirm in the data generator.
    amplitude: float
    freq: float
    # get_task() draws this from [0, 2*pi).
    phase: float
    # "A" or "B" from get_task(); "external" by default from task_spec_from_dict().
    distribution: str
|
|
|
|
@dataclass
class SLMTaskSpec:
    """Spec for one SLM (next-token prediction) task.

    Built by get_slm_task() (internal registry) or slm_task_spec_from_dict()
    (external definition). The field order is the positional constructor
    signature generated by @dataclass, so do not reorder fields.
    """

    task_id: int
    # Derived from task_id (task_id * 31337 and task_id * 131 + 7) unless an
    # external dict supplies them.
    data_seed: int
    arch_seed: int
    # Model/data shape. Registry tasks always use SLM_DEFAULT_VOCAB_SIZE,
    # context_len=64, n_layer=2, n_head=4, n_embd=128.
    vocab_size: int
    context_len: int
    n_layer: int
    n_head: int
    n_embd: int
    # Registry tasks use "default". Presumably this names the training text
    # (e.g. DEFAULT_CORPUS) — TODO confirm against the consumer.
    corpus_id: str
    # "A" (ids < 50) or "B" (ids >= 50) from get_slm_task(); "external" by
    # default from slm_task_spec_from_dict().
    distribution: str
|
|
|
|
def get_task(task_id: int) -> TaskSpec:
    """
    Look up a sinusoid regression task in the built-in registry.

    Ids 0..49 use the Distribution A ranges (DIST_A_FREQ / DIST_A_AMP, used for
    training). Ids 50 and above use the Distribution B ranges
    (DIST_B_FREQ / DIST_B_AMP, used for evaluation). Every field is derived
    arithmetically from task_id, so repeated calls with the same id return
    equal specs.

    Args:
        task_id: Non-negative registry index.

    Returns:
        The TaskSpec for that id (input_dim=output_dim=1, hidden_dim in [32, 64]).

    Raises:
        ValueError: If task_id is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # Scramble the id so neighbouring ids get unrelated parameters.
    mixed = task_id * 7919 + 1

    def unit(multiplier: int) -> float:
        # Fraction in [0, 1), in steps of 1/1000, taken from mixed * multiplier.
        return ((mixed * multiplier) % 1000) / 1000.0

    if task_id < 50:
        freq_range, amp_range, distribution = DIST_A_FREQ, DIST_A_AMP, "A"
    else:
        freq_range, amp_range, distribution = DIST_B_FREQ, DIST_B_AMP, "B"
    freq_lo, freq_hi = freq_range
    amp_lo, amp_hi = amp_range

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=32 + (mixed % 33),
        output_dim=1,
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        amplitude=amp_lo + unit(3) * (amp_hi - amp_lo),
        freq=freq_lo + unit(1) * (freq_hi - freq_lo),
        phase=unit(7) * 2 * math.pi,
        distribution=distribution,
    )
|
|
|
|
def get_slm_task(task_id: int) -> SLMTaskSpec:
    """
    Return the SLM task spec for the given task_id.

    Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
    Only the seeds and ``distribution`` depend on task_id. The architecture
    (vocab_size, context_len, n_layer, n_head, n_embd) and corpus_id are the
    same for every registry task.

    Args:
        task_id: Non-negative registry index.

    Returns:
        An SLMTaskSpec whose seeds are derived deterministically from task_id.

    Raises:
        ValueError: If task_id is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")
    # Same seed formulas as get_task() and the dict-based builders, so a given
    # task_id maps to the same seeds across task types.
    data_seed = task_id * 31337
    arch_seed = task_id * 131 + 7
    distribution = "A" if task_id < 50 else "B"
    return SLMTaskSpec(
        task_id=task_id,
        data_seed=data_seed,
        arch_seed=arch_seed,
        vocab_size=SLM_DEFAULT_VOCAB_SIZE,
        context_len=64,
        n_layer=2,
        n_head=4,
        n_embd=128,
        corpus_id="default",
        distribution=distribution,
    )
|
|
|
|
def slm_task_spec_from_dict(d: Dict[str, Any]) -> SLMTaskSpec:
    """
    Build an SLMTaskSpec from an external dict (type='slm').

    Every key is optional. Missing seeds are derived from ``task_id`` with the
    same formulas as the registry (task_id * 31337 and task_id * 131 + 7).
    Missing architecture fields fall back to vocab_size=SLM_DEFAULT_VOCAB_SIZE,
    context_len=64, n_layer=2, n_head=4, n_embd=128, corpus_id="default",
    distribution="external". Values are coerced to the field types, so
    JSON-style strings such as "64" are accepted.

    Args:
        d: Task definition, typically sent by the client. A ``type`` key, if
            present, is not read here.

    Returns:
        The corresponding SLMTaskSpec.

    Raises:
        ValueError, TypeError: If a value cannot be coerced (e.g. int("x")).
    """
    task_id = int(d.get("task_id", 0))
    return SLMTaskSpec(
        task_id=task_id,
        data_seed=int(d.get("data_seed", task_id * 31337)),
        arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
        vocab_size=int(d.get("vocab_size", SLM_DEFAULT_VOCAB_SIZE)),
        context_len=int(d.get("context_len", 64)),
        n_layer=int(d.get("n_layer", 2)),
        n_head=int(d.get("n_head", 4)),
        n_embd=int(d.get("n_embd", 128)),
        corpus_id=str(d.get("corpus_id", "default")),
        # Coerced like every other field so the spec matches its `str` annotation.
        distribution=str(d.get("distribution", "external")),
    )
|
|
|
|
# The return annotation is quoted so the PEP 604 union is not evaluated at
# import time; unquoted, `TaskSpec | SLMTaskSpec` raises TypeError on Python < 3.10.
def task_spec_from_dict(d: Dict[str, Any]) -> "TaskSpec | SLMTaskSpec":
    """
    Build a TaskSpec or SLMTaskSpec from an external dict (sent by the client).

    The ``type`` key selects the spec kind and defaults to "slm" when absent.

    For type "sinusoid": ``amplitude``, ``freq`` and ``phase`` are required.
    ``task_id`` (default 0), ``input_dim`` (default 1), ``hidden_dim``
    (default 32), ``data_seed`` / ``arch_seed`` (default: derived from
    task_id) and ``distribution`` (default "external") are optional.
    ``output_dim`` is always 1.
    For type "slm": delegates to slm_task_spec_from_dict(). vocab_size,
    context_len, n_layer, n_head and n_embd are all optional with defaults.

    Args:
        d: Task definition dict.

    Returns:
        A TaskSpec for "sinusoid", or an SLMTaskSpec for "slm".

    Raises:
        KeyError: If a required sinusoid key is missing.
        ValueError: If ``type`` is unknown, or a value cannot be coerced.
        TypeError: If a value has a type that cannot be coerced (e.g. None).
    """
    task_type = d.get("type", "slm")
    if task_type == "sinusoid":
        # Coerce up front, as slm_task_spec_from_dict() does. Otherwise a
        # string id such as "3" (common in JSON payloads) would be
        # string-repeated by the seed multiplications below and stored
        # un-coerced on the spec.
        task_id = int(d.get("task_id", 0))
        return TaskSpec(
            task_id=task_id,
            input_dim=int(d.get("input_dim", 1)),
            hidden_dim=int(d.get("hidden_dim", 32)),
            output_dim=1,
            data_seed=int(d.get("data_seed", task_id * 31337)),
            arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
            amplitude=float(d["amplitude"]),
            freq=float(d["freq"]),
            phase=float(d["phase"]),
            distribution=str(d.get("distribution", "external")),
        )
    if task_type == "slm":
        return slm_task_spec_from_dict(d)
    raise ValueError(f"Unknown task type: {task_type!r}")
|
|