| from typing import List, Tuple, Union |
|
|
| import einops |
| import numpy as np |
| import torch |
| from transformers import AutoModel, PreTrainedModel |
| from vector_quantize_pytorch import VectorQuantize |
|
|
| from .configuration_actioncodec import ActionCodecConfig |
| from .modular_actioncodec import PerceiverDecoder, PerceiverEncoder |
| from .rvq import ResidualVectorQuantize |
|
|
|
|
def trim_trailing_zeros(arr: np.ndarray) -> list[list]:
    """Strip trailing zeros from each row of a 2D array.

    Args:
        arr (np.ndarray): Array of shape (b, n).

    Returns:
        list[list]: One Python list per row with its trailing zeros removed.
            All-zero rows yield empty lists.
    """
    if arr.shape[0] == 0:
        return []

    _, n = arr.shape

    # Locate the last nonzero entry in each row by flipping the nonzero mask
    # and taking the first True along each row.
    is_nonzero = arr != 0
    flipped_mask = np.flip(is_nonzero, axis=1)
    last_nonzero_indices = n - 1 - np.argmax(flipped_mask, axis=1)
    any_nonzero_in_row = is_nonzero.any(axis=1)
    # All-zero rows get length 0 (argmax of an all-False mask is 0).
    new_lengths = (last_nonzero_indices + 1) * any_nonzero_in_row
    result = [arr[i, :length].tolist() for i, length in enumerate(new_lengths)]
|
|
| return result |
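
# Illustrative behavior of `trim_trailing_zeros` (comments only, not executed):
#   >>> trim_trailing_zeros(np.array([[1, 2, 0, 0], [0, 0, 0, 0]]))
#   [[1, 2], []]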
|
|
|
|
| class ActionCodec(PreTrainedModel): |
| """ActionCodec: A neural codec for encoding and decoding robot action sequences. |
| |
| This model uses a Perceiver-based encoder-decoder architecture with vector quantization |
| to convert continuous action sequences into discrete token sequences. It supports |
| multiple robot embodiments with different action dimensions and control frequencies. |
| |
| The model supports two vector quantization types: |
| - VQ (Vector Quantization): Single quantizer |
| - RVQ (Residual Vector Quantization): Multiple quantizers for hierarchical encoding |
| |
| Key features: |
| - Multi-embodiment support: Handle different robots with varying action dimensions |
| - Dynamic expansion: Add new robot configurations without retraining |
| - Flexible input/output: Support numpy arrays and torch tensors |
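
    Example:
        A minimal round trip (illustrative; assumes a valid `ActionCodecConfig`
        named `config`):

        >>> codec = ActionCodec(config).eval()
        >>> actions = np.zeros((1, 10, 7), dtype=np.float32)
        >>> tokens = codec.encode(actions, embodiment_ids=0)
        >>> recon, mask = codec.decode(tokens, embodiment_ids=0)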
| """ |
|
|
| config_class = ActionCodecConfig |
|
|
| def __init__(self, config: ActionCodecConfig): |
| """Initialize the ActionCodec model. |
| |
| Args: |
| config (ActionCodecConfig): Model configuration containing hyperparameters |
| and embodiment configurations. |
| |
| Raises: |
| ValueError: If configuration parameters are invalid. |
| NotImplementedError: If the specified VQ type is not supported. |
| """ |
| super().__init__(config) |
|
|
| |
| if config.n_tokens % config.n_quantizers != 0: |
| raise ValueError(f"n_tokens ({config.n_tokens}) must be divisible by n_quantizers ({config.n_quantizers})") |
|
|
| if config.n_quantizers < 1: |
| raise ValueError(f"n_quantizers must be at least 1, got {config.n_quantizers}") |
|
|
| if config.vq_codebook_size < 1: |
| raise ValueError(f"vq_codebook_size must be at least 1, got {config.vq_codebook_size}") |
|
|
| if config.z_dim < 1: |
| raise ValueError(f"z_dim must be at least 1, got {config.z_dim}") |
|
|
| if not isinstance(config.embodiment_config, dict) or len(config.embodiment_config) == 0: |
| raise ValueError( |
| "embodiment_config must be a non-empty dictionary mapping embodiment names to configurations" |
| ) |
|
|
| self.default_embodiment_id = 0 |
|
|
| |
| self.encoder = PerceiverEncoder(config) |
| self.decoder = PerceiverDecoder(config) |
|
|
| |
| if config.vq_type == "vq": |
| if config.n_quantizers != 1: |
| raise ValueError( |
| f"VQ type requires n_quantizers=1, got {config.n_quantizers}. Use RVQ type for multiple quantizers." |
| ) |
| self.vq = VectorQuantize( |
| dim=config.z_dim, |
| codebook_size=config.vq_codebook_size, |
| commitment_weight=config.vq_commitment_weight, |
| decay=config.vq_decay, |
| kmeans_init=config.vq_kmeans_init, |
| threshold_ema_dead_code=config.vq_threshold_ema_dead_code, |
| rotation_trick=False, |
| straight_through=True, |
| ) |
| elif config.vq_type == "rvq": |
| if config.n_quantizers < 2: |
| raise ValueError( |
| f"RVQ type requires n_quantizers >= 2, got {config.n_quantizers}. Use VQ type for single quantizer." |
| ) |
| self.vq = ResidualVectorQuantize( |
| dim=config.z_dim, |
| n_codebooks=config.n_quantizers, |
| codebook_size=config.vq_codebook_size, |
| codebook_dim=config.z_dim, |
| quantizer_dropout=config.vq_quantizer_dropout, |
| commitment=config.vq_commitment_weight, |
| ) |
| else: |
| raise NotImplementedError(f"VQ type '{config.vq_type}' not implemented. Supported types: 'vq', 'rvq'") |
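
        # Token bookkeeping: each of the n_quantizers codebooks emits
        # n_tokens_per_quantizer indices per sequence, so a full token sequence
        # carries n_quantizers * n_tokens_per_quantizer == config.n_tokens tokens.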
|
|
| |
| self.vocab_size = config.vq_codebook_size |
| self.num_quantizers = config.n_quantizers |
| self.n_tokens_per_quantizer = config.n_tokens // config.n_quantizers |
|
|
| def expand_embodiment(self, embodiment_config: dict): |
| """Dynamically expand the model to support new robot embodiments. |
| |
| This method allows adding new robot configurations to the codec without retraining |
| the entire model. It updates the encoder and decoder to handle the new action dimensions |
| and frequencies while preserving existing functionality for previously configured robots. |
| |
| Args: |
| embodiment_config (dict): Dictionary mapping embodiment names to their configurations. |
| Each configuration should be a dict with keys: |
| - "action_dim" (int): Action dimensionality for this embodiment. |
| - "freq" (float): Control frequency in Hz. |
| - "duration" (float): Default action sequence duration in seconds. |
| - "description" (str, optional): Human-readable description. |
| |
| Example: |
| { |
| "robot_B": { |
| "action_dim": 10, |
| "freq": 20, |
| "duration": 1.0, |
| "description": "10-dim robot at 20Hz" |
| } |
| } |
| |
| Returns: |
| ActionCodec: Returns self for method chaining. |
| |
| Note: |
| - New embodiment keys must not already exist in the current configuration. |
| - The model will automatically update max_action_dim if the new embodiment |
| has a larger action dimension. |
| - Existing embodiments will continue to work with their original configurations. |
| """ |
| if not isinstance(embodiment_config, dict): |
| raise TypeError(f"embodiment_config must be a dict, got {type(embodiment_config)}") |
| if len(embodiment_config) == 0: |
| raise ValueError("embodiment_config cannot be empty") |
|
|
| |
| overlapping_keys = set(embodiment_config.keys()) & set(self.config.embodiment_config.keys()) |
| if overlapping_keys: |
| raise ValueError(f"The following embodiment keys already exist and cannot be redefined: {overlapping_keys}") |
|
|
| self.encoder.expand_embodiment(embodiment_config) |
| self.decoder.expand_embodiment(embodiment_config) |
| self.config.embodiment_config.update(embodiment_config) |
| return self |
|
|
| def _encode( |
| self, |
| x: torch.Tensor, |
| embodiment_ids: torch.Tensor | int | None = None, |
| padding_mask: torch.Tensor | None = None, |
| ) -> torch.Tensor: |
| """Encode action sequences into latent representations. |
| |
| Args: |
            x (torch.Tensor): Action sequences to encode. Shape: (b, seq_len, max_action_dim).
                The action dimension is expected to be zero-padded up to max_action_dim,
                and `seq_len` is expected to be `int(duration * freq)` for each embodiment,
                padded to the max sequence length in the batch.
            embodiment_ids (torch.Tensor | int | None): IDs specifying which embodiment
                each sequence belongs to. Shape: (b,). If int, the same embodiment ID is
                used for all sequences in the batch. Defaults to None, which uses
                `self.default_embodiment_id`.
            padding_mask (torch.Tensor | None, optional): Padding mask over the `seq_len`
                dimension, where `False` marks padded timesteps. Shape: (b, seq_len).
                Defaults to None.
| |
| Returns: |
| torch.Tensor: Encoded latent representations. Shape: (b, n_tokens_per_quantizer, z_dim). |
| """ |
| embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
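        # The Perceiver encoder cross-attends the (masked) action sequence into a
        # fixed set of n_tokens_per_quantizer latent vectors of width z_dim.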
| z_e = self.encoder(x, embodiment_ids, padding_mask) |
| return z_e |
|
|
| def _quantize( |
| self, z_e: torch.Tensor, return_perplexity: bool = True |
    ) -> Tuple[torch.Tensor, torch.Tensor, Union[int, List[float]], torch.Tensor]:
| """Quantize encoded representations using vector quantization. |
| |
| Args: |
| z_e (torch.Tensor): Encoded latent representations to quantize. |
| Shape: (b, n_tokens_per_quantizer, z_dim). |
| return_perplexity (bool, optional): Whether to compute and return perplexity. |
| Defaults to True. |
| |
| Returns: |
            Tuple[torch.Tensor, torch.Tensor, Union[int, List[float]], torch.Tensor]:
                A tuple containing:
                - z_q (torch.Tensor): Quantized representations.
                    Shape: (b, n_tokens_per_quantizer, z_dim).
                - indices (torch.Tensor): Quantization indices.
                    Shape: (b, n_tokens_per_quantizer) for VQ (unsqueezed to
                    (b, n_tokens_per_quantizer, 1) when perplexity is computed) or
                    (b, n_tokens_per_quantizer, n_quantizers) for RVQ.
                - perplexity (Union[int, List[float]]): Per-codebook perplexities, one
                    entry per quantizer, when `return_perplexity` is True; 0 otherwise.
                - commit_loss (torch.Tensor): Commitment loss scalar tensor.
| """ |
| if isinstance(self.vq, ResidualVectorQuantize): |
| z_q, indices, _, commitment_loss, codebook_loss = self.vq(z_e) |
| commit_loss = commitment_loss.mean() + codebook_loss.mean() |
| elif isinstance(self.vq, VectorQuantize): |
| z_q, indices, commit_loss = self.vq(z_e) |
| else: |
| raise NotImplementedError(f"VQ type {type(self.vq)} not implemented") |
|
|
| if return_perplexity: |
| if len(indices.size()) < 3: |
| indices = indices.unsqueeze(-1) |
| perplexity = [] |
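            # Per-codebook perplexity: exp(entropy) of the empirical code-usage
            # distribution, aggregated across ranks when running distributed.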
| for k in range(indices.size(-1)): |
| this_indices = indices[:, :, k] |
| indices_count = torch.bincount(this_indices.view(-1), minlength=self.vq.codebook_size) |
| if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: |
| torch.distributed.all_reduce(indices_count) |
| this_avg_probs = indices_count.float() / indices_count.sum() |
| perplexity.append(((-(this_avg_probs * torch.log(this_avg_probs + 1e-10)).sum()).exp().item())) |
| else: |
| perplexity = 0 |
|
|
| return z_q, indices, perplexity, commit_loss |
|
|
| def _dequantize(self, indices: torch.Tensor) -> torch.Tensor: |
| """Dequantize token indices back to continuous latent representations. |
| |
| Args: |
| indices (torch.Tensor): Quantization indices. Shape depends on quantizer type: |
| - For VQ: (b, n_tokens) or (b, n_tokens, 1) |
| - For RVQ: (b, n_tokens_per_quantizer, n_quantizers) |
| |
| Returns: |
| torch.Tensor: Dequantized latent representations. |
| Shape: (b, n_tokens_per_quantizer, z_dim) |
| """ |
| if self.num_quantizers == 1: |
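            # A single quantizer may receive indices with a trailing singleton
            # quantizer axis, e.g. (b, n_tokens, 1); drop it before codebook lookup.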
| if len(indices.size()) == 3: |
| indices = indices.squeeze(-1) |
| if isinstance(self.vq, ResidualVectorQuantize): |
| z_q = self.vq.from_codes(indices)[0] |
| elif isinstance(self.vq, VectorQuantize): |
| z_q = self.vq.get_output_from_indices(indices) |
| else: |
| raise NotImplementedError(f"VQ type {type(self.vq)} not implemented in _dequantize") |
| return z_q |
|
|
| def _decode( |
| self, z_q: torch.Tensor, embodiment_ids: torch.Tensor | int | None = None, durations: torch.Tensor | None = None |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Decode quantized latent representations into action sequences. |
| |
| Args: |
| z_q (torch.Tensor): Quantized latent representations. |
| Shape: (b, n_tokens_per_quantizer, z_dim). |
| embodiment_ids (Union[torch.Tensor, int, None], optional): Embodiment IDs. |
| Shape: (b,) if tensor. If int, the same embodiment ID is used for all |
| sequences. Defaults to None, which uses `self.default_embodiment_id`. |
| durations (torch.Tensor | None, optional): Duration of each action sequence in seconds. |
| Shape: (b,). If None, uses default duration from embodiment_config. |
| Defaults to None. |
| |
| Returns: |
| Tuple[torch.Tensor, torch.Tensor]: A tuple containing: |
| - x_recon (torch.Tensor): Reconstructed action sequences. |
| Shape: (b, seq_len, max_action_dim). |
| - padding_mask (torch.Tensor): Padding mask indicating valid timesteps. |
| Shape: (b, seq_len), where True indicates valid timesteps. |
| """ |
| embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
| x_recon, padding_mask = self.decoder(z_q, embodiment_ids, durations) |
| return x_recon, padding_mask |
|
|
| @torch.no_grad() |
| def encode( |
| self, |
| x: Union[np.ndarray, torch.Tensor], |
| embodiment_ids: Union[List[int], int, None] = None, |
| padding_mask: Union[List[bool], np.ndarray, torch.Tensor, None] = None, |
| **kwargs, |
| ) -> List[List[int]]: |
| """Encode action sequences into latent representations (token indices). |
| |
| This method converts action sequences into discrete token indices using the encoder |
| and vector quantizer. The input can be either a numpy array or torch tensor. |
| |
| Args: |
| x (Union[np.ndarray, torch.Tensor]): Action sequences to encode. |
| Shape: (b, seq_len, max_action_dim). |
                The action dimension is expected to be zero-padded up to max_action_dim,
                and `seq_len` is expected to be `int(duration * freq)` for each embodiment,
                padded to the max sequence length in the batch.
            embodiment_ids (Union[List[int], int, None], optional): IDs specifying which
                embodiment each sequence belongs to. Shape: (b,) if list. If int, the same
                embodiment ID is used for all sequences in the batch.
                Defaults to None, which uses `self.default_embodiment_id`.
| padding_mask (Union[List[bool], np.ndarray, torch.Tensor, None], optional): |
| Padding mask, where `False` values indicate padding. Shape: (b, seq_len). |
| Defaults to None. It is used to mask the padding tokens on `seq_len` dimension. |
| **kwargs: Additional keyword arguments (currently unused, reserved for future use). |
| |
| Returns: |
            List[List[int]]: List of token sequences of shape (b, n_tokens), where
                n_tokens equals `config.n_tokens`.
| |
| Raises: |
| ValueError: If input shapes are invalid or incompatible with the model configuration. |
| TypeError: If input types are not supported. |
| |
| Examples: |
| >>> import numpy as np |
| >>> # Using numpy array |
| >>> x = np.random.randn(2, 10, 7).astype(np.float32) |
| >>> tokens = model.encode(x, embodiment_ids=[0, 0]) |
| >>> # Using torch tensor |
| >>> x_tensor = torch.randn(2, 10, 7) |
| >>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0]) |
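            >>> # With an explicit padding mask (False marks padded timesteps)
            >>> mask = np.ones((2, 10), dtype=bool)
            >>> tokens = model.encode(x_tensor, embodiment_ids=[0, 0], padding_mask=mask)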
| """ |
| self.eval() |
|
|
| |
| if isinstance(x, np.ndarray): |
| if x.ndim != 3: |
| raise ValueError( |
| f"Expected 3D input array (batch, seq_len, action_dim), got {x.ndim}D array with shape {x.shape}" |
| ) |
| x_tensor = torch.tensor(x, dtype=self.dtype, device=self.device) |
| elif isinstance(x, torch.Tensor): |
| if x.ndim != 3: |
| raise ValueError( |
| f"Expected 3D tensor (batch, seq_len, action_dim), got {x.ndim}D tensor with shape {x.shape}" |
| ) |
| x_tensor = x.to(dtype=self.dtype, device=self.device) |
| else: |
| raise TypeError(f"Input x must be numpy.ndarray or torch.Tensor, got {type(x)}") |
|
|
| |
| batch_size = x_tensor.shape[0] |
| if batch_size == 0: |
| raise ValueError("Batch size must be at least 1") |
|
|
| |
| embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
| if isinstance(embodiment_ids, int): |
| if not 0 <= embodiment_ids < len(self.config.embodiment_config): |
| raise ValueError( |
| f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). " |
| f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}" |
| ) |
| embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device) |
| elif isinstance(embodiment_ids, list): |
| if len(embodiment_ids) != batch_size: |
| raise ValueError( |
| f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})" |
| ) |
| for eid in embodiment_ids: |
| if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config): |
| raise ValueError( |
| f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})" |
| ) |
| embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device) |
| else: |
| raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}") |
|
|
| |
| padding_mask_tensor = None |
| if padding_mask is not None: |
| if isinstance(padding_mask, (list, np.ndarray)): |
| padding_mask_tensor = torch.tensor(padding_mask, dtype=torch.bool, device=self.device) |
| elif isinstance(padding_mask, torch.Tensor): |
| padding_mask_tensor = padding_mask.to(dtype=torch.bool, device=self.device) |
| else: |
| raise TypeError( |
| f"padding_mask must be List[bool], np.ndarray, torch.Tensor, or None, got {type(padding_mask)}" |
| ) |
| if padding_mask_tensor.shape != (batch_size, x_tensor.shape[1]): |
| raise ValueError( |
| f"padding_mask shape {padding_mask_tensor.shape} does not match expected shape " |
| f"({batch_size}, {x_tensor.shape[1]})" |
| ) |
|
|
        # Gradients are already disabled by the @torch.no_grad() decorator.
        z_e = self._encode(x_tensor, embodiment_ids_tensor, padding_mask_tensor)
        _, indices, _, _ = self._quantize(z_e, return_perplexity=False)
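
        # Flatten RVQ indices (b, n_tokens_per_quantizer, n_quantizers) so that
        # tokens are grouped by quantizer level: [q0_t0, ..., q0_t{m-1}, q1_t0, ...].
        # Plain VQ indices are already (b, n_tokens).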
|
|
| |
| if len(indices.size()) > 2: |
| codes_list = einops.rearrange(indices, "b n s -> b (s n)").cpu() |
| else: |
| codes_list = indices.cpu() |
|
|
| codes_list = codes_list.tolist() |
| return codes_list |
|
|
| @torch.no_grad() |
| def decode( |
| self, |
| tokens: Union[List[List[int]], np.ndarray, torch.Tensor], |
| embodiment_ids: Union[List[int], int, None] = None, |
| durations: Union[List[float], np.ndarray, torch.Tensor, None] = None, |
| **kwargs, |
| ) -> Tuple[np.ndarray, np.ndarray]: |
| """Decode token sequences into action sequences. |
| |
| This method reconstructs action sequences from discrete token indices using the |
| vector quantizer and decoder. The input tokens can be a list of lists, numpy array, |
| or torch tensor. |
| |
| Args: |
| tokens (Union[List[List[int]], np.ndarray, torch.Tensor]): Token sequences to decode. |
| Shape: (b, n_tokens), where n_tokens must be divisible by `n_tokens_per_quantizer`. |
                For RVQ, tokens are grouped by quantizer level, matching the layout
                produced by `encode`: [q0_t0, q0_t1, ..., q0_t{m-1}, q1_t0, ...],
                where m = n_tokens_per_quantizer.
            embodiment_ids (Union[List[int], int, None], optional): IDs specifying which
                embodiment each sequence belongs to. Shape: (b,) if list. If int, the same
                embodiment ID is used for all sequences in the batch.
                Defaults to None, which uses `self.default_embodiment_id`.
| durations (Union[List[float], np.ndarray, torch.Tensor, None], optional): |
| Duration of each action sequence in seconds. Shape: (b,). |
| If None, the duration is inferred from the default values in `embodiment_config`. |
| Defaults to None. |
| **kwargs: Additional keyword arguments (currently unused, reserved for future use). |
| |
| Returns: |
| Tuple[np.ndarray, np.ndarray]: A tuple containing: |
| - reconstructed_actions: Reconstructed action sequences. |
| Shape: (b, seq_len, max_action_dim). |
| - padding_mask: Padding mask indicating valid timesteps. |
| Shape: (b, seq_len), where True indicates valid timesteps. |
| |
| Raises: |
| ValueError: If token sequence length is invalid or incompatible with the model configuration. |
| TypeError: If input types are not supported. |
| |
| Examples: |
| >>> # Using list of lists |
| >>> tokens = [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]] |
| >>> actions, mask = model.decode(tokens, embodiment_ids=[0, 0]) |
| >>> # Using numpy array |
| >>> tokens_np = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) |
| >>> actions, mask = model.decode(tokens_np, embodiment_ids=[0, 0]) |
| >>> # Using torch tensor |
| >>> tokens_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) |
| >>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0]) |
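            >>> # Overriding the decoded duration (in seconds) per sequence
            >>> actions, mask = model.decode(tokens_tensor, embodiment_ids=[0, 0], durations=[1.0, 1.0])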
| """ |
| self.eval() |
|
|
| |
| if isinstance(tokens, list): |
| if not all(isinstance(seq, list) for seq in tokens): |
| raise TypeError("If tokens is a list, all elements must be lists") |
| if len(tokens) == 0: |
| raise ValueError("Tokens list cannot be empty") |
| if not all(isinstance(val, (int, np.integer)) for seq in tokens for val in seq): |
| raise TypeError("All token values must be integers") |
| tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device) |
| elif isinstance(tokens, np.ndarray): |
| if tokens.ndim != 2: |
| raise ValueError( |
| f"Expected 2D array (batch, n_tokens), got {tokens.ndim}D array with shape {tokens.shape}" |
| ) |
| if not np.issubdtype(tokens.dtype, np.integer): |
| raise TypeError(f"Tokens array must have integer dtype, got {tokens.dtype}") |
| tokens_tensor = torch.tensor(tokens, dtype=torch.long, device=self.device) |
| elif isinstance(tokens, torch.Tensor): |
| if tokens.ndim != 2: |
| raise ValueError( |
| f"Expected 2D tensor (batch, n_tokens), got {tokens.ndim}D tensor with shape {tokens.shape}" |
| ) |
            if tokens.is_floating_point() or tokens.is_complex():
                raise TypeError(f"Tokens tensor must have integer dtype, got {tokens.dtype}")
| tokens_tensor = tokens.to(dtype=torch.long, device=self.device) |
| else: |
| raise TypeError(f"tokens must be List[List[int]], np.ndarray, or torch.Tensor, got {type(tokens)}") |
|
|
| batch_size, n_tokens = tokens_tensor.shape |
| if batch_size == 0: |
| raise ValueError("Batch size must be at least 1") |
| if n_tokens == 0: |
| raise ValueError("Token sequence length must be at least 1") |
|
|
| |
| if n_tokens % self.n_tokens_per_quantizer != 0: |
| raise ValueError( |
| f"Token sequence length ({n_tokens}) must be divisible by tokens per quantizer " |
| f"({self.n_tokens_per_quantizer}). Total tokens: {n_tokens}, " |
| f"Expected multiple of: {self.n_tokens_per_quantizer}. " |
| f"Number of quantizers: {self.num_quantizers}, Total tokens per sequence: {self.config.n_tokens}" |
| ) |
|
|
| |
| if tokens_tensor.min() < 0 or tokens_tensor.max() >= self.vocab_size: |
| raise ValueError( |
| f"Token values must be in range [0, {self.vocab_size}), " |
| f"got range [{tokens_tensor.min().item()}, {tokens_tensor.max().item()}]" |
| ) |
|
|
| |
| embodiment_ids = embodiment_ids if embodiment_ids is not None else self.default_embodiment_id |
| if isinstance(embodiment_ids, int): |
| if not 0 <= embodiment_ids < len(self.config.embodiment_config): |
| raise ValueError( |
| f"embodiment_id {embodiment_ids} is out of range [0, {len(self.config.embodiment_config)}). " |
| f"Available embodiment IDs: {list(range(len(self.config.embodiment_config)))}" |
| ) |
| embodiment_ids_tensor = torch.tensor([embodiment_ids] * batch_size, dtype=torch.long, device=self.device) |
| elif isinstance(embodiment_ids, list): |
| if len(embodiment_ids) != batch_size: |
| raise ValueError( |
| f"Length of embodiment_ids ({len(embodiment_ids)}) must match batch size ({batch_size})" |
| ) |
| for eid in embodiment_ids: |
| if not isinstance(eid, int) or not 0 <= eid < len(self.config.embodiment_config): |
| raise ValueError( |
| f"Invalid embodiment_id {eid}. Must be an integer in range [0, {len(self.config.embodiment_config)})" |
| ) |
| embodiment_ids_tensor = torch.tensor(embodiment_ids, dtype=torch.long, device=self.device) |
| else: |
| raise TypeError(f"embodiment_ids must be int, List[int], or None, got {type(embodiment_ids)}") |
|
|
| |
| durations_tensor = None |
| if durations is not None: |
| if isinstance(durations, (list, np.ndarray)): |
| durations_tensor = torch.tensor(durations, dtype=torch.float32, device=self.device) |
| elif isinstance(durations, torch.Tensor): |
| durations_tensor = durations.to(dtype=torch.float32, device=self.device) |
| else: |
| raise TypeError( |
| f"durations must be List[float], np.ndarray, torch.Tensor, or None, got {type(durations)}" |
| ) |
| if durations_tensor.ndim != 1: |
| raise ValueError( |
| f"durations must be 1D, got {durations_tensor.ndim}D with shape {durations_tensor.shape}" |
| ) |
| if len(durations_tensor) != batch_size: |
| raise ValueError(f"Length of durations ({len(durations_tensor)}) must match batch size ({batch_size})") |
| if (durations_tensor <= 0).any(): |
| raise ValueError("All durations must be positive") |
|
|
| |
| indices = einops.rearrange(tokens_tensor, "b (n m) -> b m n", m=self.n_tokens_per_quantizer) |
|
|
        # Gradients are already disabled by the @torch.no_grad() decorator.
        z_q = self._dequantize(indices)
        x_recon, padding_mask = self._decode(z_q, embodiment_ids_tensor, durations_tensor)
|
|
        # Keep the mask boolean so that True marks valid timesteps, as documented.
        return x_recon.float().cpu().numpy(), padding_mask.bool().cpu().numpy()
|
|
| def forward( |
| self, |
| x: Union[torch.Tensor, np.ndarray], |
| embodiment_ids: Union[torch.Tensor, int, List[int], None] = None, |
| padding_mask: Union[torch.Tensor, List[bool], np.ndarray, None] = None, |
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
| """Forward pass through the full ActionCodec pipeline. |
| |
| This method performs encoding, quantization, and decoding in a single forward pass. |
| It is primarily used during training to compute reconstruction loss and commitment loss. |
| Both numpy arrays and torch tensors are supported as input. |
| |
| Args: |
| x (Union[torch.Tensor, np.ndarray]): Action sequences to process. |
| Shape: (b, seq_len, max_action_dim). |
| embodiment_ids (Union[torch.Tensor, int, List[int], None], optional): |
| Embodiment IDs. Shape: (b,) if tensor or list. If int, same ID for all sequences. |
| Defaults to None, which uses `self.default_embodiment_id`. |
| padding_mask (Union[torch.Tensor, List[bool], np.ndarray, None], optional): |
| Padding mask. Shape: (b, seq_len). Defaults to None. |
| |
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing:
                - x_recon (torch.Tensor): Reconstructed action sequences.
                    Shape: (b, seq_len, max_action_dim).
                - recon_mask (torch.Tensor): Reconstruction mask indicating valid timesteps.
                    Shape: (b, seq_len), where True indicates valid timesteps.
                - commit_loss (torch.Tensor): Scalar commitment loss from the vector
                    quantizer, to be added to the training objective.
| |
| Note: |
| - For inference use cases, prefer using `encode()` and `decode()` methods separately. |
| - If you need token indices, use the `encode()` method instead. |
| """ |
| |
| if isinstance(x, np.ndarray): |
| x = torch.tensor(x, dtype=self.dtype, device=self.device) |
|
|
| |
        # Convert list IDs to a tensor; int and tensor inputs pass through unchanged
        # (an int ID is broadcast across the batch by the encoder/decoder).
        if isinstance(embodiment_ids, list):
            embodiment_ids = torch.tensor(embodiment_ids, device=x.device, dtype=torch.long)
|
|
| |
| if isinstance(padding_mask, (list, np.ndarray)): |
| padding_mask = torch.tensor(padding_mask, device=x.device, dtype=torch.bool) |
|
|
| |
        z_e = self._encode(x, embodiment_ids, padding_mask)
        z_q, _, _, commit_loss = self._quantize(z_e, return_perplexity=False)
        x_recon, recon_mask = self._decode(z_q, embodiment_ids)

        return x_recon, recon_mask, commit_loss
|
|
|
|
| AutoModel.register(ActionCodecConfig, ActionCodec) |
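
# Illustrative loading sketch (the path is hypothetical; the registration above
# lets AutoModel resolve ActionCodecConfig to this class):
# >>> from transformers import AutoModel
# >>> codec = AutoModel.from_pretrained("path/to/actioncodec_checkpoint")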
|
|
| __all__ = ["ActionCodec"] |
|
|
|
|
| if __name__ == "__main__": |
| print("=== ActionCodec Comprehensive Test ===\n") |
|
|
| |
| initial_config = { |
| "robot_A": {"action_dim": 7, "freq": 10, "duration": 1, "description": "Robot A"}, |
| } |
|
|
| |
| config = ActionCodecConfig( |
| embodiment_config=initial_config, |
| n_tokens=16, |
| n_quantizers=4, |
| vq_type="rvq", |
| vq_codebook_size=256, |
| encoder_dim=128, |
| decoder_dim=128, |
| ) |
|
|
| |
| latent_seq_len = int(config.n_tokens // config.n_quantizers) |
| print(f"Config: {config.n_quantizers} quantizers, {latent_seq_len} latent vectors per sequence.") |
|
|
| codec = ActionCodec(config) |
| codec.eval() |
|
|
| |
| print("\n--- Test 1: Basic Encode/Decode ---") |
| batch_size = 2 |
| seq_len_A = 10 |
|
|
| |
| x = np.random.randn(batch_size, seq_len_A, 7).astype(np.float32) |
| |
| padding_mask = np.ones((batch_size, seq_len_A), dtype=bool) |
| padding_mask[1, 5:] = False |
|
|
| embodiment_ids = [0, 0] |
|
|
| |
| codes = codec.encode(x, embodiment_ids, padding_mask) |
| print(f"Encoded codes shape (list length): {len(codes)} x {len(codes[0])}") |
|
|
| |
| assert len(codes[0]) == config.n_tokens, f"Expected {config.n_tokens} tokens, got {len(codes[0])}" |
|
|
| |
| x_recon, recon_mask = codec.decode(codes, embodiment_ids) |
| print(f"Reconstructed shape: {x_recon.shape}") |
| print(f"Recon mask shape: {recon_mask.shape}") |
|
|
| assert x_recon.shape == (batch_size, seq_len_A, 7) |
|
|
| |
| print("\n--- Test 2: Dynamic Expansion ---") |
| new_robot_config = {"robot_B": {"action_dim": 10, "freq": 20, "duration": 1, "description": "Robot B (Larger)"}} |
|
|
| print("Expanding codec to include Robot B (10 dims, 20Hz)...") |
| codec.expand_embodiment(new_robot_config) |
|
|
| assert codec.encoder.max_action_dim == 10 |
| assert codec.decoder.max_action_dim == 10 |
| print("✅ Expansion successful.") |
|
|
| |
| print("\n--- Test 3: Mixed Batch Inference ---") |
|
|
| |
| |
| |
| |
|
|
| batch_x_mixed = np.zeros((2, 20, 10), dtype=np.float32) |
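
    # Robot A: 10 valid steps (1 s @ 10 Hz), 7 action dims, zero-padded to (20, 10).
    # Robot B: 20 valid steps (1 s @ 20 Hz), 10 action dims.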
|
|
| |
| data_A = np.random.randn(10, 7) |
| batch_x_mixed[0, :10, :7] = data_A |
|
|
| |
| data_B = np.random.randn(20, 10) |
| batch_x_mixed[1, :20, :10] = data_B |
|
|
| |
| |
| mixed_ids = [0, 1] |
|
|
| |
| mixed_mask = np.zeros((2, 20), dtype=bool) |
| mixed_mask[0, :10] = True |
| mixed_mask[1, :20] = True |
|
|
| print("Encoding mixed batch...") |
| mixed_codes = codec.encode(batch_x_mixed, mixed_ids, mixed_mask) |
|
|
| print("Decoding mixed batch...") |
| |
| durations = [1, 1] |
| x_recon_mixed, dec_mask_mixed = codec.decode(mixed_codes, mixed_ids, durations) |
|
|
| print(f"Mixed Recon Shape: {x_recon_mixed.shape}") |
|
|
| |
| |
| valid_A = dec_mask_mixed[0].sum() |
| valid_B = dec_mask_mixed[1].sum() |
|
|
| print(f"Valid steps detected by Decoder: Robot A={valid_A}, Robot B={valid_B}") |
|
|
| assert valid_A == 10 |
| assert valid_B == 20 |
|
|
| |
| |
| |
| print("✅ Mixed batch processed successfully.") |
|
|
| print("\n✨ All systems go.") |
|
|