import numpy as np import torch from torch import nn from torch.nn import functional as F def get_patches_fast(images, cfg): from einops import rearrange batch_size, height, width, channels = images.shape patch_size = cfg.patch_size ## n_patches = 8 patches = rearrange(images[:,:,:,:3], 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size) if channels > 3: ## History stacking in the channel dimension for observations only, not goal images. patches = rearrange(images, 'b (h p1) (w p2) (c hs) -> b (h w hs) (p1 p2 c)', p1 = patch_size, p2 = patch_size, hs=cfg.policy.obs_stacking) ## Stack the history in the channel dimension return patches def calc_positional_embeddings(sequence_length, d): result = torch.ones(sequence_length, d) for i in range(sequence_length): for j in range(d): result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d))) return result class Head(nn.Module): """ one head of self-attention """ def __init__(self, head_size, n_embd, dropout): super().__init__() self.key = nn.Linear(n_embd, head_size, bias=False) self.query = nn.Linear(n_embd, head_size, bias=False) self.value = nn.Linear(n_embd, head_size, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): B,T,C = x.shape # TODO: ## Provide the block masking logic for the attention head k = self.key(x) q = self.query(x) wei = q @ k.transpose(-2,-1) * C**-0.5 #wei = wei.masked_fill(mask == 0, float('-inf')) if mask is not None: wei = wei.masked_fill(mask == 0, float('-inf')) wei = F.softmax(wei, dim=-1) wei = self.dropout(wei) v = self.value(x) out = wei @ v return out class MultiHeadAttention(nn.Module): def __init__(self, num_heads, head_size, n_embd, dropout): super().__init__() self.heads = nn.ModuleList([Head(head_size, n_embd=n_embd, dropout=dropout) for _ in range(num_heads)]) self.proj = nn.Linear(n_embd, n_embd) self.dropout = nn.Dropout(dropout) def forward(self, x, mask=None): with torch.profiler.record_function("Self-Attention"): out = torch.cat([h(x, mask) for h in self.heads], dim=-1) out = self.dropout(self.proj(out)) return out class FeedFoward(nn.Module): def __init__(self, n_embd, dropout): super().__init__() self.net = nn.Sequential( nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout), ) def forward(self, x): return self.net(x) class Block(nn.Module): def __init__(self, n_embd, n_head, dropout): super().__init__() head_size = n_embd // n_head self.sa = MultiHeadAttention(n_head, head_size, n_embd=n_embd, dropout=dropout) self.ffwd = FeedFoward(n_embd, dropout) self.ln1 = nn.LayerNorm(n_embd) self.ln2 = nn.LayerNorm(n_embd) def forward(self, x, mask=None): x = x + self.sa(self.ln1(x), mask) x = x + self.ffwd(self.ln2(x)) return x class GRP(nn.Module): def __init__(self, cfg, mlp_ratio=4): super(GRP, self).__init__() self._cfg = cfg chars = cfg.dataset.chars_list cfg.vocab_size = len(chars) # TODO: ## Provide the logic for the GRP network # 4) Transformer encoder blocks # 5) Classification MLPk # 1) Embeddings # Calculate patch dimension: patch_size * patch_size * 3 (RGB) patch_dim = cfg.patch_size * cfg.patch_size * 3 self.patch_embedding = nn.Linear(patch_dim, cfg.n_embd) # Token embedding for text goals (if not using T5) # Check if dataset config exists and has encode_with_t5, else assume False or handle safely use_t5 = False if hasattr(cfg, 'dataset') and hasattr(cfg.dataset, 'encode_with_t5'): use_t5 = cfg.dataset.encode_with_t5 if not use_t5: self.token_embedding_table = nn.Embedding(cfg.vocab_size, cfg.n_embd) # Learnable positional embeddings or fixed buffer # Calculate maximum sequence length: CLS + text tokens + separator + goal img patches + obs patches num_patches_per_image = cfg.n_patches * cfg.n_patches max_seq_len = 1 + cfg.max_block_size + 1 + num_patches_per_image + num_patches_per_image * cfg.policy.obs_stacking # Using a fixed buffer as per the helper function provided pos_emb = calc_positional_embeddings(max_seq_len, cfg.n_embd) self.register_buffer('pos_embedding', pos_emb) # Special Tokens self.cls_token = nn.Parameter(torch.randn(1, 1, cfg.n_embd)) self.goal_token = nn.Parameter(torch.randn(1, 1, cfg.n_embd)) # Dropout self.dropout = nn.Dropout(cfg.dropout) # 2) Transformer encoder blocks # Corrected: using cfg.n_blocks instead of cfg.n_layer self.blocks = nn.Sequential(*[ Block(cfg.n_embd, n_head=cfg.n_head, dropout=cfg.dropout) for _ in range(cfg.n_blocks) ]) self.ln_f = nn.LayerNorm(cfg.n_embd) # Final layer norm # 3) Classification MLP / Action Head # Corrected: using cfg.action_dim (root) instead of cfg.env.action_dim output_dim = cfg.action_dim * cfg.policy.action_stacking self.lm_head = nn.Linear(cfg.n_embd, output_dim) self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, images, goals_txt, goal_imgs, targets=None, pose=None, mask_=False): #device = images.device n, c, h, w = images.shape obs_patches = get_patches_fast(images, self._cfg) patches_g = get_patches_fast(goal_imgs, self._cfg) if self._cfg.dataset.encode_with_t5: goals_e = goals_txt B, T, E = goals_txt.shape else: goals_e = self.token_embedding_table(goals_txt) B, E = goals_txt.shape T = self._cfg.max_block_size # TODO: ## Provide the logic to produce the output and loss for the GRP # Map the vector corresponding to each patch to the hidden size dimension # Adding classification and goal_img tokens to the tokens # Adding positional embedding # Compute blocked masks # Transformer Blocks # Getting the classification token only # Compute output and loss obs_emb = self.patch_embedding(obs_patches) goal_img_emb = self.patch_embedding(patches_g) cls_tokens = self.cls_token.expand(B, -1, -1) sep_tokens = self.goal_token.expand(B, -1, -1) # Concatenate everything into one sequence: # [CLS] + [Goal Text] + [Goal Separator] + [Goal Img] + [Observation Patches] x = torch.cat((cls_tokens, goals_e, sep_tokens, goal_img_emb, obs_emb), dim=1) # Adding positional embedding # We slice the pre-calculated buffer to the current sequence length seq_len = x.shape[1] x = x + self.pos_embedding[:seq_len, :].to(self._cfg.device) x = self.dropout(x) # Compute blocked masks mask = None if mask_: # Create a causal mask (lower triangular) mask = torch.tril(torch.ones(seq_len, seq_len, device=device)) # Transformer Blocks for block in self.blocks: x = block(x, mask) x = self.ln_f(x) # Getting the classification token only (index 0) # This token aggregates information from the entire sequence cls_out = x[:, 0, :] # Compute output logits = self.lm_head(cls_out) # Shape: (B, action_dim * action_stacking) # Compute loss loss = None if targets is not None: # Typically MSE loss for continuous action regression loss = F.mse_loss(logits, targets) return (logits, loss) def resize_image(self, image): """ Docstring for resize_image :param self: Description :param image: Description self._resize_state = lambda sf: cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1])) # resize state """ import cv2 import numpy as _np img = _np.array(image, dtype=_np.float32) img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1])) return img def normalize_state(self, image): """ Docstring for preprocess_state :param self: Description :param image: Description self._encode_state = lambda af: ((af/(255.0)*2.0)-1.0) # encoder: take a float, output an integer self._resize_state = lambda sf: cv2.resize(np.array(sf, dtype=np.float32), (cfg.image_shape[0], cfg.image_shape[1])) # resize state """ # img = _np.array(image, dtype=_np.float32) # img = cv2.resize(img, (self._cfg.image_shape[0], self._cfg.image_shape[1])) enc = ((image / 255.0) * 2.0) - 1.0 # t = _torch.tensor(enc, dtype=_torch.float32, device=self._cfg.device) return enc def preprocess_state(self, image): img = self.resize_image(image) img = self.normalize_state(img) return img def preprocess_goal_image(self, image): return self.preprocess_state(image) def encode_text_goal(self, goal, tokenizer=None, text_model=None): import numpy as _np import torch as _torch if self._cfg.dataset.encode_with_t5: if tokenizer is None or text_model is None: raise ValueError("tokenizer and text_model must be provided when using T5 encoding") # TODO: ## Provide the logic converting text goal to T5 embedding tensor with _torch.no_grad(): # Tokenize the goal text inputs = tokenizer(goal, return_tensors="pt", padding=True, truncation=True, max_length=self._cfg.max_block_size) # Move inputs to the correct device inputs = {k: v.to(self._cfg.device) for k, v in inputs.items()} # Get embeddings from the text model (Encoder) outputs = text_model(**inputs) # Use last hidden state: (Batch, Seq_Len, Hidden_Dim) embeddings = outputs.last_hidden_state return embeddings else: pad = " " * self._cfg.max_block_size goal_ = goal[:self._cfg.max_block_size] + pad[len(goal):self._cfg.max_block_size] try: stoi = {c: i for i, c in enumerate(self._cfg.dataset.chars_list)} ids = [stoi.get(c, 0) for c in goal_] except Exception: ids = [0] * self._cfg.max_block_size return _torch.tensor(_np.expand_dims(_np.array(ids, dtype=_np.int64), axis=0), dtype=_torch.long, device=self._cfg.device) def process_text_embedding_for_buffer(self, goal, tokenizer=None, text_model=None): """ Process text goal embedding for storing in the circular buffer. Returns a numpy array of shape (max_block_size, n_embd) without batch dimension. """ import numpy as _np if tokenizer is None or text_model is None: raise ValueError("tokenizer and text_model must be provided when using T5 encoding") goal_ = _np.zeros((self._cfg.max_block_size, self._cfg.n_embd), dtype=_np.float32) input_ids = tokenizer(goal, return_tensors="pt").input_ids goal_t = text_model.encoder(input_ids).last_hidden_state.detach().cpu().numpy() goal_[:len(goal_t[0]), :] = goal_t[0][:self._cfg.max_block_size] return goal_ def decode_action(self, action_tensor): """ Docstring for decode_action :param self: Description :param action_tensor: Description self._decode_action = lambda binN: (binN * action_std) + action_mean # Undo mapping to [-1, 1] """ import torch as _torch ## The action tensor is of shape (batch_size, action_dim * action_stacking) so we need to repeat the mean and std per action stacking action_mean = _torch.tensor(np.repeat(self._cfg.dataset.action_mean, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device) action_std = _torch.tensor(np.repeat(self._cfg.dataset.action_std, self._cfg.policy.action_stacking), dtype=action_tensor.dtype, device=action_tensor.device) return (action_tensor * action_std) + action_mean def encode_action(self, action_float): """ Docstring for encode_action :param self: Description :param action_float: Description self._encode_action = lambda af: (af - action_mean)/(action_std) # encoder: take a float, output an integer """ import torch as _torch action_mean = _torch.tensor(self._cfg.dataset.action_mean, dtype=action_float.dtype, device=action_float.device) action_std = _torch.tensor(self._cfg.dataset.action_std, dtype=action_float.dtype, device=action_float.device) return (action_float - action_mean) / action_std @torch.no_grad() def estimate_loss(model, dataset): out = {} model.eval() for split in ['train', 'val']: losses = torch.zeros(model._cfg.eval_iters) for k in range(model._cfg.eval_iters): X, x_pose, x_goal, x_goal_img, Y = dataset.get_batch_grp(split, model._cfg, model._cfg.batch_size) logits, loss = model(X, x_goal, x_goal_img, Y, pose=x_pose) losses[k] = loss.item() out[split] = losses.mean() model.train() return out