# Provenance (residue from the Hugging Face file page): uploaded by SavirD via
# "Upload folder using huggingface_hub", commit 4d2821f (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Task registry for meta-learning.
Tasks can come from the internal registries (get_task(task_id) for sinusoids,
get_slm_task(task_id) for SLM) or be provided from outside via task_spec_from_dict(),
in which case the client sends the task definition to the environment.
Supports sinusoid (regression) and SLM (next-token prediction) task types.
"""
from dataclasses import dataclass
from typing import Any, Dict, List
import math
from .slm_model import DEFAULT_VOCAB_SIZE as SLM_DEFAULT_VOCAB_SIZE
# Distribution A: 50 training tasks (low-freq sinusoids)
TRAIN_TASK_IDS: List[int] = list(range(50))
# Distribution B: held-out eval tasks (high-freq sinusoids — different distribution)
EVAL_TASK_IDS: List[int] = [50, 51]
# SLM: 50 train tasks, 2 eval. All registry SLM tasks use the same "default"
# corpus and model shape; train and eval tasks differ only through their
# task_id-derived seeds and distribution label.
SLM_TRAIN_TASK_IDS: List[int] = list(range(50))
SLM_EVAL_TASK_IDS: List[int] = [50, 51]
# Fixed small corpus for SLM (character-level): five pangrams (199 characters)
# repeated 200 times, i.e. 39,800 characters (~40KB), so tasks are reproducible.
DEFAULT_CORPUS: str = (
    "The quick brown fox jumps over the lazy dog. "
    "Pack my box with five dozen liquor jugs. "
    "How vexingly quick daft zebras jump. "
    "Sphinx of black quartz, judge my vow. "
    "The five boxing wizards jump quickly. "
) * 200  # repeat to get enough length for sampling
# (lo, hi) bounds for each distribution's frequency and amplitude. get_task maps
# each task_id to a value in [lo, hi). Phase has no per-distribution bounds:
# it is always drawn from [0, 2*pi).
DIST_A_FREQ: tuple[float, float] = (1.0, 3.0)
DIST_A_AMP: tuple[float, float] = (0.5, 2.0)
DIST_B_FREQ: tuple[float, float] = (4.0, 6.0)
DIST_B_AMP: tuple[float, float] = (0.3, 1.5)
@dataclass
class TaskSpec:
    """Spec for one sinusoidal regression task.

    The regression target is ``y = amplitude * sin(2*pi*freq*x + phase)``.
    Instances are built either by the internal registry (get_task) or from a
    client-supplied dict (task_spec_from_dict).
    """
    task_id: int
    input_dim: int  # 1 for scalar sinusoid input
    hidden_dim: int  # model hidden size; get_task picks a value in [32, 64]
    output_dim: int  # always 1 as built in this module (scalar sinusoid target)
    data_seed: int  # presumably seeds data sampling — TODO confirm against consumers
    arch_seed: int  # presumably seeds model construction/init — TODO confirm
    # Sinusoidal target: y = amplitude * sin(2*pi*freq*x + phase)
    amplitude: float
    freq: float
    phase: float  # radians; get_task draws it from [0, 2*pi)
    distribution: str  # "A" (train), "B" (eval), or "external" (client-supplied dict)
@dataclass
class SLMTaskSpec:
    """Spec for one SLM (next-token prediction) task.

    Built by the internal registry (get_slm_task) or from a client-supplied dict
    (slm_task_spec_from_dict). Both default to a 64-token context, 2 layers,
    4 heads, 128-dim embeddings and the "default" corpus.
    """
    task_id: int
    data_seed: int  # presumably seeds data/batch sampling — TODO confirm
    arch_seed: int  # presumably seeds model initialisation — TODO confirm
    vocab_size: int  # defaults to slm_model.DEFAULT_VOCAB_SIZE in this module
    context_len: int  # block size
    n_layer: int  # number of layers (default 2)
    n_head: int  # number of attention heads (default 4)
    n_embd: int  # embedding width (default 128)
    corpus_id: str  # e.g. "default"
    distribution: str  # "A" or "B" or "external"
def get_task(task_id: int) -> TaskSpec:
    """
    Look up the registry sinusoid task for ``task_id``.

    IDs 0..49 belong to Distribution A (training, low frequency); IDs 50 and
    above belong to Distribution B (evaluation, high frequency). Every field is
    a deterministic function of ``task_id``, so the same id always yields the
    same task.

    Raises:
        ValueError: if ``task_id`` is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")

    # One mixed value from which all per-task knobs are derived.
    mix = task_id * 7919 + 1

    def unit(multiplier):
        # Fraction in [0, 1) on a 1/1000 grid, varied by ``multiplier``.
        return ((mix * multiplier) % 1000) / 1000.0

    if task_id < 50:
        freq_bounds, amp_bounds, label = DIST_A_FREQ, DIST_A_AMP, "A"
    else:
        freq_bounds, amp_bounds, label = DIST_B_FREQ, DIST_B_AMP, "B"
    f_lo, f_hi = freq_bounds
    a_lo, a_hi = amp_bounds

    # Linear interpolation inside each distribution's [lo, hi) range.
    frequency = f_lo + unit(1) * (f_hi - f_lo)
    amp = a_lo + unit(3) * (a_hi - a_lo)
    phase_rad = unit(7) * 2 * math.pi

    return TaskSpec(
        task_id=task_id,
        input_dim=1,
        hidden_dim=32 + (mix % 33),  # width in [32, 64]
        output_dim=1,
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        amplitude=amp,
        freq=frequency,
        phase=phase_rad,
        distribution=label,
    )
def get_slm_task(task_id: int) -> SLMTaskSpec:
    """
    Return the SLM task spec for the given task_id.

    Task IDs 0..49 = Distribution A (train), 50+ = Distribution B (eval).
    All registry SLM tasks share the same model shape (64-token context,
    2 layers, 4 heads, 128-dim embeddings) and the "default" corpus. They differ
    only in their task_id-derived seeds and their distribution label.

    Raises:
        ValueError: if task_id is negative.
    """
    if task_id < 0:
        raise ValueError(f"task_id must be >= 0, got {task_id}")
    # Seed formulas match get_task and slm_task_spec_from_dict's defaults, so a
    # given task_id gets the same seeds everywhere. (The unused
    # ``r = task_id * 7919 + 1`` mixing value from get_task was dropped: no SLM
    # field depends on it.)
    return SLMTaskSpec(
        task_id=task_id,
        data_seed=task_id * 31337,
        arch_seed=task_id * 131 + 7,
        vocab_size=SLM_DEFAULT_VOCAB_SIZE,
        context_len=64,
        n_layer=2,
        n_head=4,
        n_embd=128,
        corpus_id="default",
        distribution="A" if task_id < 50 else "B",
    )
def slm_task_spec_from_dict(d: Dict[str, Any]) -> SLMTaskSpec:
    """
    Convert a client-supplied mapping (type='slm') into an SLMTaskSpec.

    Every key is optional:
    - Missing seeds fall back to the task_id-derived formulas used by the registry.
    - Missing architecture fields fall back to the registry defaults: 64-token
      context, 2 layers, 4 heads and 128-dim embeddings.
    - Numeric fields are coerced with int(), so e.g. the string "64" becomes 64.
    - "corpus_id" is coerced with str().
    - "distribution" is passed through unchanged.
    """
    tid = int(d.get("task_id", 0))

    def as_int(key, fallback):
        return int(d.get(key, fallback))

    return SLMTaskSpec(
        task_id=tid,
        data_seed=as_int("data_seed", tid * 31337),
        arch_seed=as_int("arch_seed", tid * 131 + 7),
        vocab_size=as_int("vocab_size", SLM_DEFAULT_VOCAB_SIZE),
        context_len=as_int("context_len", 64),
        n_layer=as_int("n_layer", 2),
        n_head=as_int("n_head", 4),
        n_embd=as_int("n_embd", 128),
        corpus_id=str(d.get("corpus_id", "default")),
        distribution=d.get("distribution", "external"),
    )
def task_spec_from_dict(d: Dict[str, Any]) -> TaskSpec | SLMTaskSpec:
    """
    Build a TaskSpec or SLMTaskSpec from an external dict (sent by the client).

    The "type" key selects the spec kind and defaults to "slm".

    For type "sinusoid":
        Required: "amplitude", "freq", "phase".
        Optional: "task_id" (default 0), "input_dim" (default 1),
        "hidden_dim" (default 32), "data_seed", "arch_seed" (both default to
        the task_id-derived registry formulas) and "distribution" (default
        "external").
    For type "slm":
        vocab_size, context_len, n_layer, n_head, n_embd, etc. are all optional
        with defaults; see slm_task_spec_from_dict.

    Raises:
        KeyError: a required sinusoid field is missing.
        ValueError: unknown "type", or a numeric field is a non-numeric string.
        TypeError: a numeric field has a type int()/float() cannot convert.
    """
    task_type = d.get("type", "slm")
    if task_type == "sinusoid":
        # Coerce like slm_task_spec_from_dict does. Clients sending JSON may pass
        # task_id as a string (e.g. "3"); without int() the seed defaults below
        # would do string repetition / str + int and raise TypeError, and a
        # non-int task_id would be stored on the spec as-is.
        task_id = int(d.get("task_id", 0))
        return TaskSpec(
            task_id=task_id,
            input_dim=int(d.get("input_dim", 1)),
            hidden_dim=int(d.get("hidden_dim", 32)),
            output_dim=1,  # sinusoid target is scalar
            data_seed=int(d.get("data_seed", task_id * 31337)),
            arch_seed=int(d.get("arch_seed", task_id * 131 + 7)),
            amplitude=float(d["amplitude"]),
            freq=float(d["freq"]),
            phase=float(d["phase"]),
            distribution=d.get("distribution", "external"),
        )
    if task_type == "slm":
        return slm_task_spec_from_dict(d)
    raise ValueError(f"Unknown task type: {task_type!r}")