| import re |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from diffusers.schedulers.scheduling_ddpm import DDPMScheduler |
| from diffusers.schedulers.scheduling_dpmsolver_multistep import \ |
| DPMSolverMultistepScheduler |
|
|
| from models.hub_mixin import CompatiblePyTorchModelHubMixin |
| from models.rdt.model import RDT |
|
|
|
|
class RDTRunner(
    nn.Module,
    CompatiblePyTorchModelHubMixin,
    repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b",
):
    """Training / sampling wrapper around the RDT diffusion backbone.

    The runner owns:
      * ``self.model``: the ``RDT`` network that predicts the diffusion target,
      * three condition adaptors projecting language, image and
        state/action tokens to the backbone's ``hidden_size``,
      * a ``DDPMScheduler`` used to add noise during training and a
        ``DPMSolverMultistepScheduler`` used for sampling.

    ``forward`` computes the training loss; ``predict_action`` samples an
    action sequence.
    """

    # NOTE: the constructor parameters are intentionally left un-annotated;
    # hub mixins may inspect ``__init__`` annotations when (de)serialising the
    # saved config, so adding them could alter loading behaviour.
    def __init__(self, *, action_dim, pred_horizon, config,
                 lang_token_dim, img_token_dim, state_token_dim,
                 max_lang_cond_len, img_cond_len, lang_pos_embed_config=None,
                 img_pos_embed_config=None, dtype=torch.bfloat16):
        """Build the backbone, the condition adaptors and the noise schedulers.

        Args:
            action_dim: last-axis size of the sampled actions; forwarded to
                ``RDT`` as ``output_dim``.
            pred_horizon: number of action steps sampled per call; forwarded
                to ``RDT`` as ``horizon``.
            config: nested mapping. Keys read here:
                ``config['rdt']`` (``hidden_size``, ``depth``, ``num_heads``),
                ``config['lang_adaptor']``, ``config['img_adaptor']``,
                ``config['state_adaptor']`` (projector type strings, see
                ``build_condition_adapter``) and ``config['noise_scheduler']``
                (``num_train_timesteps``, ``num_inference_timesteps``,
                ``beta_schedule``, ``prediction_type``, ``clip_sample``).
            lang_token_dim: feature size of the raw language tokens.
            img_token_dim: feature size of the raw image tokens.
            state_token_dim: feature size of the raw state tokens. The state
                adaptor takes ``state_token_dim * 2`` input features because
                the action mask is concatenated to its input.
            max_lang_cond_len: forwarded to ``RDT``.
            img_cond_len: forwarded to ``RDT``.
            lang_pos_embed_config: forwarded to ``RDT``.
            img_pos_embed_config: forwarded to ``RDT``.
            dtype: forwarded to ``RDT``.
        """
        super().__init__()

        # Diffusion backbone.
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Adaptors that project every condition to the backbone width.
        self.lang_adaptor = self.build_condition_adapter(
            config['lang_adaptor'],
            in_features=lang_token_dim,
            out_features=hidden_size,
        )
        self.img_adaptor = self.build_condition_adapter(
            config['img_adaptor'],
            in_features=img_token_dim,
            out_features=hidden_size,
        )
        # The input is [state-or-action, mask] concatenated on the last axis,
        # hence twice the state width.
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,
            out_features=hidden_size,
        )

        # Training scheduler (adds noise) and sampling scheduler (denoises).
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report the parameter count of the backbone plus the three adaptors.
        counted_modules = (
            self.model, self.lang_adaptor, self.img_adaptor, self.state_adaptor)
        num_params = sum(
            p.numel() for module in counted_modules for p in module.parameters())
        print("Diffusion params: %e" % num_params)

    def build_condition_adapter(
            self, projector_type, in_features, out_features) -> nn.Module:
        """Create a projector mapping ``in_features`` to ``out_features``.

        Args:
            projector_type: ``'linear'`` for a single ``nn.Linear``, or
                ``'mlp<N>x_gelu'`` for ``N`` linear layers separated by
                tanh-approximated GELU activations.
            in_features: input feature size of the first linear layer.
            out_features: output feature size of every linear layer.

        Returns:
            The projector module.

        Raises:
            ValueError: if ``projector_type`` matches neither form.
        """
        if projector_type == 'linear':
            return nn.Linear(in_features, out_features)

        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match is None:
            raise ValueError(f'Unknown projector type: {projector_type}')

        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(in_features, out_features)]
        # One (GELU, Linear) pair for every layer beyond the first.
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU(approximate="tanh"))
            modules.append(nn.Linear(out_features, out_features))
        return nn.Sequential(*modules)

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        Project all condition tokens to the backbone's hidden size.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim * 2), i.e. the
            state (or state-action) tokens with the action mask already
            concatenated on the last axis.

        return: adapted (..., hidden_size) for all input tokens, in the
            order (lang, img, state)
        '''
        adapted_lang = self.lang_adaptor(lang_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)

        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond,
                           state_traj, action_mask, ctrl_freqs):
        '''
        Denoise a random action sequence with the sampling scheduler.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)

        Note: [noisy_action, action_mask] is fed through ``state_adaptor``,
        whose input size is ``state_token_dim * 2``, so ``action_dim`` must
        equal ``state_token_dim``. Gradient tracking is not disabled here.
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        # Start from pure Gaussian noise.
        noisy_action = torch.randn(
            size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
            dtype=dtype, device=device)
        # Broadcast the mask over the horizon axis.
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set the denoising schedule for this sampling run.
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Build the state-action token sequence for this step.
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the scheduler's target at timestep t.
            model_output = self.model(state_action_traj, ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond, img_cond,
                                      lang_mask=lang_attn_mask)

            # One denoising step: x_t -> x_{t-1}.
            noisy_action = self.noise_scheduler_sample.step(
                model_output, t, noisy_action).prev_sample
            # Cast back to the working dtype after the scheduler step.
            noisy_action = noisy_action.to(dtype)

        # Finally zero out the invalid action dimensions.
        noisy_action = noisy_action * action_mask

        return noisy_action

    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens,
                     state_tokens, action_gt, action_mask, ctrl_freqs
                     ) -> torch.Tensor:
        '''
        Compute the diffusion training loss (MSE against the chosen target).

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor

        raises: ValueError if ``self.prediction_type`` is neither
            'epsilon' nor 'sample'.
        '''
        # Fail fast: reject an unsupported target before sampling noise and
        # running the (expensive) forward pass.
        pred_type = self.prediction_type
        if pred_type not in ('epsilon', 'sample'):
            raise ValueError(f"Unsupported prediction type {pred_type}")

        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Noise to be added to the ground-truth actions.
        noise = torch.randn(
            action_gt.shape, dtype=action_gt.dtype, device=device
        )
        # One random diffusion timestep per sample.
        timesteps = torch.randint(
            0, self.num_train_timesteps,
            (batch_size,), dtype=torch.long, device=device
        )
        # Forward diffusion: noise the clean actions at those timesteps.
        noisy_action = self.noise_scheduler.add_noise(
            action_gt, noise, timesteps)

        # Sequence of [state, noisy actions] along the token axis ...
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # ... with the mask appended to every token along the feature axis.
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)

        # Project all conditions to the backbone width.
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_action_traj)

        pred = self.model(state_action_traj, ctrl_freqs,
                          timesteps, lang_cond, img_cond,
                          lang_mask=lang_attn_mask)

        # 'epsilon' regresses the added noise, 'sample' the clean actions.
        target = noise if pred_type == 'epsilon' else action_gt

        return F.mse_loss(pred, target)

    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens,
                       action_mask, ctrl_freqs):
        '''
        Sample an action sequence conditioned on language, images and state.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Append the mask to the state features, then project all conditions.
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_tokens)

        # Run the denoising loop.
        action_pred = self.conditional_sample(
            lang_cond, lang_attn_mask, img_cond,
            state_traj, action_mask, ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Return the training loss; all arguments are passed to ``compute_loss``."""
        return self.compute_loss(*args, **kwargs)
|
|