Muqeeth's picture
Add files using upload-large-folder tool
9ba32f5 verified
"""
File: mllm/training/tally_rollout.py
Summary: Serializes rollout data into tallies for downstream processing.
"""
import json
import os
from copy import deepcopy
from typing import Union
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer
class RolloutTallyItem:
    """Lightweight data container that keeps rollout-aligned metric matrices.

    Attributes are stored host-side: any identifier collection passed as a
    ``torch.Tensor`` is converted to a numpy array, and the metric tensor is
    always stored as a 2-D numpy array.
    """

    def __init__(
        self,
        crn_ids: list[str],
        rollout_ids: list[str],
        agent_ids: list[str],
        metric_matrix: torch.Tensor,
    ):
        """Store the identifiers and the metric tensor as host-side data.

        Args:
            crn_ids: Identifiers stored as ``self.crn_ids``. A ``torch.Tensor``
                is detached, moved to CPU and converted to a numpy array; any
                other value is stored unchanged.
            rollout_ids: Same handling as ``crn_ids``; stored as
                ``self.rollout_ids``.
            agent_ids: Same handling as ``crn_ids``; stored as
                ``self.agent_ids``.
            metric_matrix: 1-D or 2-D tensor. A 1-D tensor is reshaped to a
                single row ``(1, n)``; the result is stored as a numpy array
                in ``self.metric_matrix``.

        Raises:
            AssertionError: If ``metric_matrix`` is 0-dimensional or has more
                than 2 dimensions.
        """

        def _to_host(ids):
            # Only tensors need conversion; everything else passes through.
            if isinstance(ids, torch.Tensor):
                return ids.detach().cpu().numpy()
            return ids

        self.crn_ids = _to_host(crn_ids)
        self.rollout_ids = _to_host(rollout_ids)
        self.agent_ids = _to_host(agent_ids)

        matrix = metric_matrix.detach().cpu()
        assert (
            matrix.ndim in (1, 2)
        ), "Metric matrix must have less than or equal to 2 dimensions"
        if matrix.ndim == 1:
            # Promote a flat vector to a single-row matrix.
            matrix = matrix.reshape(1, -1)
        if matrix.dtype == torch.bfloat16:
            # numpy has no bfloat16 dtype, so widen to float32 before export.
            matrix = matrix.float()
        self.metric_matrix = matrix.numpy()
class RolloutTally:
    """
    Tally is a utility class for collecting and storing training metrics.
    It supports adding metrics at specified paths and saving them to disk.

    ``self.metrics`` is a nested dict keyed by path components; each leaf is a
    list holding deep copies of the items passed to ``add_metric``.
    """

    def __init__(self):
        """Initializes an empty RolloutTally; takes no arguments."""
        # Nested dict: path components -> ... -> list of added items.
        self.metrics = {}

    def reset(self):
        """Reset the tally to an empty dict."""
        self.metrics = {}

    def get_from_nested_dict(self, dictio: dict, path: list[str]):
        """
        Retrieve the entry stored at ``path`` inside ``dictio``.

        Note that intermediate dicts along ``path[:-1]`` are created in
        ``dictio`` if they are missing, so this lookup can mutate ``dictio``.

        Args:
            dictio (dict): Nested dict to look into.
            path (list): List of keys; the last key selects the leaf.
        Returns:
            The value stored under the last key, or None if it is absent.
        Raises:
            AssertionError: If ``path`` is not a list.
        """
        assert isinstance(path, list), "Path must be list."
        for sp in path[:-1]:
            dictio = dictio.setdefault(sp, {})
        return dictio.get(path[-1], None)

    def set_at_path(self, dictio: dict, path: list[str], value):
        """
        Store ``value`` at ``path``; helper used by ``add_metric``.

        Args:
            dictio (dict): Nested dict to write into.
            path (list): List of keys; missing intermediate dicts are created.
            value: Object stored under the last key, replacing any existing one.
        """
        for sp in path[:-1]:
            dictio = dictio.setdefault(sp, {})
        dictio[path[-1]] = value

    def add_metric(self, path: list[str], rollout_tally_item: "RolloutTallyItem"):
        """
        Adds a metric to the base tally at the specified path.

        Args:
            path (list): List of keys representing the path in the base tally.
            rollout_tally_item (RolloutTallyItem): The rollout tally item to add.
        """
        # Deep copy so later mutation by the caller cannot alter the tally.
        rollout_tally_item = deepcopy(rollout_tally_item)
        array_list = self.get_from_nested_dict(dictio=self.metrics, path=path)
        if array_list is None:
            # First item at this path: start a new leaf list.
            self.set_at_path(dictio=self.metrics, path=path, value=[rollout_tally_item])
        else:
            array_list.append(rollout_tally_item)

    def save(self, identifier: str, folder: str):
        """
        Persist the tally as a pickle (metrics only) under ``folder``.

        The file is written to ``<folder>/<identifier>.rt_tally.pkl`` and holds
        ``{"metrics": self.metrics}``. Saving is best-effort: a failure while
        writing or pickling is logged with its traceback and not re-raised.

        Args:
            identifier (str): File name prefix for the pickle.
            folder (str): Target directory; created if it does not exist.
        """
        import logging
        import pickle

        os.makedirs(name=folder, exist_ok=True)
        pkl_path = os.path.join(folder, f"{identifier}.rt_tally.pkl")
        payload = {"metrics": self.metrics}
        try:
            with open(pkl_path, "wb") as f:
                pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception:
            # Do not abort the caller over a failed dump, but make the failure
            # visible instead of silently dropping it.
            logging.getLogger(__name__).exception(
                "Failed to save rollout tally to %s", pkl_path
            )