import copy
from typing import Any, Dict, Optional

from transformers import AutoConfig, PretrainedConfig
|
|
|
|
class ActionCodecConfig(PretrainedConfig):
    """Configuration for the action codec (encoder / vector quantizer / decoder).

    Stores per-embodiment action specifications plus hyperparameters for the
    perceiver-style encoder/decoder and the vector-quantization bottleneck.

    Args:
        embodiment_config: Mapping from embodiment name to its action spec
            (keys: ``action_dim``, ``freq`` in Hz, ``duration`` in seconds,
            ``description``). Defaults to three built-in robot embodiments.
        n_tokens: Number of latent tokens produced by the encoder.
        n_quantizers: Number of residual quantizers.
        z_dim: Dimensionality of the latent/quantized vectors.
        vq_type: Vector-quantizer variant identifier (e.g. ``"vq"``).
        vq_codebook_size: Number of entries in each VQ codebook.
        vq_commitment_weight: Weight of the VQ commitment loss term.
        vq_decay: EMA decay for codebook updates.
        vq_kmeans_init: Whether to initialize codebooks with k-means.
        vq_threshold_ema_dead_code: EMA threshold below which a code is
            considered dead and re-initialized.
        vq_quantizer_dropout: Dropout probability applied across quantizers.
        encoder_dim / decoder_dim: Transformer hidden size.
        encoder_n_layers / decoder_n_layers: Number of transformer layers.
        encoder_n_heads / decoder_n_heads: Number of attention heads.
        encoder_add_self_attn / decoder_add_self_attn: Add self-attention
            blocks in addition to cross-attention.
        encoder_add_causal_mask / decoder_add_causal_mask: Apply a causal
            attention mask.
        encoder_pos_encoding_type / decoder_pos_encoding_type: Positional
            encoding scheme (e.g. ``"fourier"``).
        decoder_cls_size: Size of the decoder's classification head output.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "action_codec"

    def __init__(
        self,
        embodiment_config: Optional[Dict[str, Any]] = None,
        n_tokens: int = 16,
        n_quantizers: int = 1,
        z_dim: int = 512,
        vq_type: str = "vq",
        vq_codebook_size: int = 2048,
        vq_commitment_weight: float = 0.25,
        vq_decay: float = 0.99,
        vq_kmeans_init: bool = True,
        vq_threshold_ema_dead_code: int = 2,
        vq_quantizer_dropout: float = 0.25,
        encoder_dim: int = 256,
        encoder_n_layers: int = 6,
        encoder_n_heads: int = 8,
        encoder_add_self_attn: bool = False,
        encoder_add_causal_mask: bool = False,
        encoder_pos_encoding_type: str = "fourier",
        decoder_dim: int = 256,
        decoder_n_layers: int = 6,
        decoder_n_heads: int = 8,
        decoder_add_self_attn: bool = False,
        decoder_add_causal_mask: bool = False,
        decoder_pos_encoding_type: str = "fourier",
        decoder_cls_size: int = 1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if embodiment_config is None:
            # Built-in embodiments used when the caller supplies none.
            embodiment_config = {
                "franka_libero_20hz": {
                    "action_dim": 7,
                    "freq": 20,
                    "duration": 1,
                    "description": "20Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "widowx_bridge_5hz": {
                    "action_dim": 7,
                    "freq": 5,
                    "duration": 1,
                    "description": "5Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
                "franka_droid_15hz": {
                    "action_dim": 7,
                    "freq": 15,
                    "duration": 1,
                    "description": "15Hz 7-dim action for 1s. Delta eef position (xyz), orientation (rpy), and gripper position (1 open/0 close).",
                },
            }
        # Deep-copy so later mutation of this config never aliases the
        # caller's dict (or the default literal above).
        self.embodiment_config = copy.deepcopy(embodiment_config)

        # Latent bottleneck.
        self.n_tokens = n_tokens
        self.n_quantizers = n_quantizers
        self.z_dim = z_dim

        # Encoder hyperparameters.
        self.encoder_dim = encoder_dim
        self.encoder_n_layers = encoder_n_layers
        self.encoder_n_heads = encoder_n_heads
        self.encoder_add_self_attn = encoder_add_self_attn
        self.encoder_add_causal_mask = encoder_add_causal_mask
        self.encoder_pos_encoding_type = encoder_pos_encoding_type

        # Decoder hyperparameters.
        self.decoder_dim = decoder_dim
        self.decoder_n_layers = decoder_n_layers
        self.decoder_n_heads = decoder_n_heads
        self.decoder_add_self_attn = decoder_add_self_attn
        self.decoder_add_causal_mask = decoder_add_causal_mask
        self.decoder_pos_encoding_type = decoder_pos_encoding_type
        self.decoder_cls_size = decoder_cls_size

        # Vector-quantizer hyperparameters.
        self.vq_type = vq_type
        self.vq_codebook_size = vq_codebook_size
        self.vq_commitment_weight = vq_commitment_weight
        self.vq_decay = vq_decay
        self.vq_kmeans_init = vq_kmeans_init
        self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code
        self.vq_quantizer_dropout = vq_quantizer_dropout
|
|
|
|
class ActionCodecConfigOld(PretrainedConfig):
    """Legacy action-codec configuration driven by dotted class paths.

    NOTE(review): shares ``model_type = "action_codec"`` with
    :class:`ActionCodecConfig`; only the new class is registered with
    ``AutoConfig``, so registering this one as well would conflict —
    presumably kept only to load older checkpoints. Confirm before removal.

    Args:
        horizon: Number of action steps per chunk.
        action_dim: Dimensionality of a single action vector.
        action_encoding: Identifier for the action-encoding scheme.
        horizon_patch_size: Temporal patch size applied over the horizon.
        encoder_class: Dotted import path of the encoder module class.
        decoder_class: Dotted import path of the decoder module class.
        vq_class: Dotted import path of the vector-quantizer class.
        encoder_kwargs: Keyword arguments for ``encoder_class``; defaults
            to a 12-layer perceiver encoder mapping ``horizon`` -> 16 tokens.
        decoder_kwargs: Keyword arguments for ``decoder_class``; defaults
            to the mirror-image decoder mapping 16 tokens -> ``horizon``.
        vq_kwargs: Keyword arguments for ``vq_class``; defaults to a
            2048-entry EMA codebook with k-means initialization.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
        decoder_kwargs: Optional[Dict[str, Any]] = None,
        vq_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Shallow-copy caller-provided kwargs so top-level mutation of this
        # config does not alter the caller's dict (nested values are shared).
        self.encoder_kwargs = (
            dict(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            dict(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            dict(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
|
|
|
|
class BPEActionCodecConfig(PretrainedConfig):
    """Configuration for the BPE action codec.

    Structurally identical to :class:`ActionCodecConfigOld` but registered
    under its own ``model_type`` (``"bpe_action_codec"``), so BPE checkpoints
    resolve to this class via ``AutoConfig``.

    Args:
        horizon: Number of action steps per chunk.
        action_dim: Dimensionality of a single action vector.
        action_encoding: Identifier for the action-encoding scheme.
        horizon_patch_size: Temporal patch size applied over the horizon.
        encoder_class: Dotted import path of the encoder module class.
        decoder_class: Dotted import path of the decoder module class.
        vq_class: Dotted import path of the vector-quantizer class.
        encoder_kwargs: Keyword arguments for ``encoder_class``; defaults
            to a 12-layer perceiver encoder mapping ``horizon`` -> 16 tokens.
        decoder_kwargs: Keyword arguments for ``decoder_class``; defaults
            to the mirror-image decoder mapping 16 tokens -> ``horizon``.
        vq_kwargs: Keyword arguments for ``vq_class``; defaults to a
            2048-entry EMA codebook with k-means initialization.
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    model_type = "bpe_action_codec"

    def __init__(
        self,
        horizon: int = 20,
        action_dim: int = 7,
        action_encoding: str = "independent_v2",
        horizon_patch_size: int = 1,
        encoder_class: str = "action_codec.modules.perceiver.PerceiverEncoder",
        decoder_class: str = "action_codec.modules.perceiver.PerceiverDecoder",
        vq_class: str = "vector_quantize_pytorch.VectorQuantize",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
        decoder_kwargs: Optional[Dict[str, Any]] = None,
        vq_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.horizon = horizon
        self.action_dim = action_dim
        self.action_encoding = action_encoding
        self.horizon_patch_size = horizon_patch_size
        self.encoder_class = encoder_class
        self.decoder_class = decoder_class
        self.vq_class = vq_class
        # Shallow-copy caller-provided kwargs so top-level mutation of this
        # config does not alter the caller's dict (nested values are shared).
        self.encoder_kwargs = (
            dict(encoder_kwargs)
            if encoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": horizon,
                "out_len": 16,
                "num_layers": 12,
                "num_heads": 4,
                "output_round": -1.0,
            }
        )
        self.decoder_kwargs = (
            dict(decoder_kwargs)
            if decoder_kwargs is not None
            else {
                "dim": 384,
                "in_len": 16,
                "out_len": horizon,
                "num_layers": 12,
                "num_heads": 4,
            }
        )
        self.vq_kwargs = (
            dict(vq_kwargs)
            if vq_kwargs is not None
            else {
                "dim": 512,
                "codebook_size": 2048,
                "kmeans_init": True,
                "kmeans_iters": 10,
                "decay": 0.99,
                "commitment_weight": 0.25,
                "rotation_trick": False,
                "threshold_ema_dead_code": 2,
                "use_cosine_sim": False,
                "codebook_diversity_loss_weight": 0.0,
            }
        )
|
|
|
|
# Make both configs discoverable through transformers' AutoConfig machinery,
# keyed by their respective model_type strings.
AutoConfig.register("action_codec", ActionCodecConfig)
AutoConfig.register("bpe_action_codec", BPEActionCodecConfig)

# Enable auto-class mapping for serialized/remote-code loading.
# NOTE(review): BPEActionCodecConfig is NOT registered for auto class here,
# even though it is registered with AutoConfig above — confirm whether this
# asymmetry is intentional.
ActionCodecConfig.register_for_auto_class()

# Public API of this module.
__all__ = ["ActionCodecConfig", "BPEActionCodecConfig"]
|
|