chat_env / models.py
burtenshaw's picture
burtenshaw HF Staff
Upload folder using huggingface_hub
6183f0f verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Chat Environment.
The Chat environment provides a chat-based interface for LLMs with support
for tokenization and message history management.
"""
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field, field_validator
def _flatten_tokens(value) -> list[int]:
"""Coerce nested tensor-like or sequence inputs into a flat token list."""
if hasattr(value, "tolist") and callable(value.tolist):
value = value.tolist()
if isinstance(value, tuple):
value = list(value)
if isinstance(value, list):
flattened: list[int] = []
for item in value:
flattened.extend(_flatten_tokens(item))
return flattened
return [int(value)]
class ChatAction(Action):
"""Action for chat environments.
Contains tokens that represent the action to be taken.
This interfaces directly with models.
"""
tokens: list[int] = Field(..., min_length=1)
@field_validator("tokens", mode="before")
@classmethod
def _coerce_tokens(cls, value):
"""Accept either tensors or JSON arrays on the public HTTP surface."""
if isinstance(value, (list, tuple)) or hasattr(value, "tolist"):
return _flatten_tokens(value)
raise TypeError("tokens must be provided as a sequence of token ids")
class ChatState(State):
"""State of the ChatEnvironment containing message history."""
# TODO: revert to list[Message] once openenv-core ships typing_extensions.TypedDict
# in interfaces.py and chat_env/pyproject.toml pins to that release.
history_messages: list[dict[str, str]] = Field(default_factory=list)
history_tokens: list[list[int]] = Field(default_factory=list) # Same len as messages
class ChatObservation(Observation):
"""Observation returned by ChatEnvironment.
Contains the message history in Huggingface format (list of dicts with role/content)
and the tokenized representation of the entire conversation.
The environment owns the tokenizer and generates the tokens from the messages.
Example:
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "How tall is the Eiffel Tower?"},
]
tokens = tensor([1, 2, 3, 4, 5, ...]) # tokenized entire conversation
"""
# TODO: revert to list[Message] (same as above)
messages: list[dict[str, str]] = Field(default_factory=list)
tokens: list[int] = Field(default_factory=list)
# Inherited Fields from Observation ABC: reward, done, metadata