| | """ |
| | train.py |
| | |
| | Main training script for VITRA Vision-Language-Action (VLA) models. |
| | Supports distributed training with FSDP (Fully Sharded Data Parallel) strategy. |
| | """ |
| |
|
| | import argparse |
| | import copy |
| | import datetime |
| | import faulthandler |
| | import json |
| | import os |
| | import random |
| | from pathlib import Path |
| | from typing import Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.distributed as dist |
| | import wandb |
| | from torch.utils.data import DataLoader |
| |
|
| | from vitra.datasets.materialize import get_vla_dataset_and_collator |
| | from vitra.models.vla_builder import build_vla, load_vla_checkpoint |
| | from vitra.training import VLAMetrics |
| | from vitra.utils import ( |
| | find_last_checkpoint, |
| | get_epoch_and_step_from_checkpoint, |
| | set_global_seed, |
| | setup_seed, |
| | ) |
| | from vitra.training.fsdp import VLAFSDPStrategy |
| | from vitra.utils.config_utils import load_config |
| | from vitra.utils.overwatch import initialize_overwatch |
| |
|
| | |
| | |
# Turn off parallelism inside the HuggingFace tokenizers library
# (presumably to avoid its fork warning once DataLoader workers start -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Permit TF32 math for cuDNN kernels and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Module-level logger handle used by every function in this script.
overwatch = initialize_overwatch(__name__)
| |
|
def experiment(variant):
    """
    Run one VITRA VLA training experiment from start to finish.

    In order, this function: selects the CUDA device for the local rank, logs in
    to Weights & Biases, creates the output folders, writes ``config.json`` on
    rank 0, builds the model (restoring a checkpoint or pretrained weights when
    configured), builds the dataset, FSDP strategy, metrics and dataloader, runs
    the training loop, and finally destroys the process group.

    The function calls ``dist.barrier()``, so the default process group must
    already be initialized by the caller.

    Args:
        variant: Configuration dictionary. Keys read here include ``log_root``,
            ``output_root``, ``cache_root``, ``task_name`` (optional),
            ``batch_size``, ``total_batch_size``, ``use_bf16``, ``seed``,
            ``resume``, ``model_load_path``, ``pretrain_path`` (optional),
            ``debug`` (optional), ``fwd_pred_next_n``, ``vla_name``,
            ``wandb_project``, ``wandb_entity``, ``num_workers``,
            ``prefetch_factor``, ``save_steps`` and the nested ``train_dataset``
            and ``trainer`` dictionaries.

    Raises:
        ValueError: If the ``WANDB_API_KEY`` environment variable is not set.
        NotImplementedError: If ``variant["vla_name"]`` is not
            ``"VITRA_Paligemma"``.
    """
    # Bind this process to its local GPU; the same id is handed to the strategy below.
    device_id = overwatch.local_rank()
    torch.cuda.set_device(device_id)
    torch.cuda.empty_cache()

    overwatch.info("VITRA VLA Training :: Creating Folders", ctx_level=1)
    wandb_api_key = os.getenv("WANDB_API_KEY")
    if wandb_api_key is None:
        raise ValueError("Please set the WANDB_API_KEY environment variable.")
    wandb.login(key=wandb_api_key)

    os.makedirs(variant["log_root"], exist_ok=True)
    os.makedirs(variant["output_root"], exist_ok=True)
    os.makedirs(variant["cache_root"], exist_ok=True)

    # The run id doubles as the checkpoint sub-directory name, so its format is
    # kept exactly as-is (including the literal "None" prefix when no task name
    # is configured); changing it would orphan existing checkpoints.
    run_id = variant["task_name"] if "task_name" in variant else None
    batch_size = variant["batch_size"]
    total_batch_size = variant["total_batch_size"]
    run_id = f"{run_id}_TB{total_batch_size}_B{batch_size}_bf16{variant['use_bf16']}"

    checkpoint_dir = os.path.join(variant["output_root"], run_id)
    os.makedirs(checkpoint_dir, exist_ok=True)

    worker_init_fn = set_global_seed(variant["seed"], get_worker_init_fn=True)

    def posix_to_str(d):
        """Recursively replace ``Path`` objects with strings inside dicts/lists."""
        if isinstance(d, dict):
            return {k: posix_to_str(v) for k, v in d.items()}
        elif isinstance(d, list):
            return [posix_to_str(v) for v in d]
        elif isinstance(d, Path):
            return str(d)
        else:
            return d

    # `variant_str` is a deep copy that still contains Path objects;
    # `copied_variant` is the JSON-serializable version of it.
    variant_str = copy.deepcopy(variant)
    copied_variant = posix_to_str(variant_str)

    # Only rank 0 writes/prints the config; everyone waits at the barrier below.
    if overwatch.rank() == 0:
        with open(os.path.join(checkpoint_dir, "config.json"), "w") as f:
            json.dump(copied_variant, f, indent=2)
        overwatch.info(f"Config saved to {checkpoint_dir}", ctx_level=1)
        print(json.dumps(copied_variant, indent=2))

    dist.barrier()

    overwatch.info("Loading model", ctx_level=1)
    resume_step = 0
    resume_epoch = 0
    model_load_path = variant["model_load_path"]

    resume_requested = bool(variant["resume"])
    if resume_requested:
        # No explicit checkpoint given: fall back to the latest one in this run's folder.
        if model_load_path is None:
            model_load_path = find_last_checkpoint(checkpoint_dir)

        if model_load_path is not None:
            resume_epoch, resume_step = get_epoch_and_step_from_checkpoint(model_load_path)
            if overwatch.rank() == 0:
                overwatch.info(
                    f"Resume from {model_load_path}, epoch: {resume_epoch}, step: {resume_step}",
                    ctx_level=1
                )

    # Decide once whether a checkpoint is being restored so that the weights and
    # the optimizer/scheduler state are always restored together.
    restore_from_checkpoint = resume_requested and model_load_path is not None

    model = build_vla(configs=variant)
    pretrain_path = variant.get("pretrain_path", None)
    if restore_from_checkpoint:
        model = load_vla_checkpoint(model, os.path.join(model_load_path, "weights.pt"))
    elif pretrain_path is not None:
        # A directory is expected to hold `weights.pt`; anything else is used as the file itself.
        if os.path.isdir(pretrain_path):
            model = load_vla_checkpoint(model, os.path.join(pretrain_path, "weights.pt"))
        else:
            model = load_vla_checkpoint(model, pretrain_path)

    model = model.train()
    model.trainable_params_setup()
    model.model.use_bf16 = variant["use_bf16"]
    model.use_bf16 = variant["use_bf16"]

    # Debug mode freezes every parameter of the inner `model.model`.
    if variant.get("debug", False):
        for p in model.model.parameters():
            p.requires_grad = False

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())
    if overwatch.rank() == 0:
        overwatch.info(f"Trainable Model Parameters: {total_params/1e6:.2f}M/{all_params/1e6:.2f}M")

    processor = model.processor

    # The dataset is sharded by world size / global rank.
    vla_dataset, collator, batch_sampler = get_vla_dataset_and_collator(
        variant["train_dataset"]["data_root_dir"],
        variant["train_dataset"]["data_mix"],
        augmentation=variant["train_dataset"]["augmentation"],
        shard_num=dist.get_world_size(),
        shard_index=dist.get_rank(),
        seed=variant["seed"],
        future_action_window_size=variant["fwd_pred_next_n"] - 1,
        processor=processor,
        batch_size=batch_size,
        normalization=variant["train_dataset"].get("normalization", True),
        flip_augmentation=variant["train_dataset"].get("flip_augmentation", 1.0),
        set_none_ratio=variant["train_dataset"].get("set_none_ratio", 0.0),
        action_type=variant["train_dataset"].get('action_type', 'angle'),
        use_rel=variant["train_dataset"].get('use_rel', False),
        rel_mode=variant["train_dataset"].get('rel_mode', "step"),
        clip_len=variant["train_dataset"].get('clip_len', None),
        state_mask_prob=variant["train_dataset"].get('state_mask_prob', 0.1),
    )

    training_strategy = VLAFSDPStrategy(
        vla=model,
        device_id=device_id,
        stage=None,
        epochs=variant["trainer"]["max_epochs"],
        max_steps=variant["trainer"]["max_steps"],
        global_batch_size=variant["total_batch_size"],
        per_device_batch_size=batch_size,
        learning_rate=variant["trainer"]["learning_rate"],
        weight_decay=variant["trainer"]["weight_decay"],
        max_grad_norm=variant["trainer"]["gradient_clip_val"],
        lr_scheduler_type=variant["trainer"]["lr_scheduler_type"],
        warmup_ratio=variant["trainer"]["warmup_ratio"],
        enable_gradient_checkpointing=variant["trainer"]["enable_gradient_checkpointing"],
        enable_mixed_precision_training=variant["trainer"]["enable_mixed_precision_training"],
        reduce_in_full_precision=variant["trainer"]["reduce_in_full_precision"],
        action_model_learning_rate=variant["trainer"].get("action_model_learning_rate", None),
        action_model_weight_decay=variant["trainer"].get("action_model_weight_decay", None),
        sharding_strategy=variant["trainer"].get("sharding_strategy", "shard-grad-op"),
        cognition_token_weight_decay=variant["trainer"].get("cognition_token_weight_decay", True),
        llm_freeze_step=variant["trainer"].get("llm_freeze_step", 0),
        # NOTE(review): the config key says "..._action_head" while the keyword
        # says "..._action_model" -- confirm this naming mismatch is intended.
        move_word_embedding_to_action_model=variant["trainer"].get("move_word_embedding_to_action_head", False),
        optimizer_betas=variant["trainer"].get("optimizer_betas", (0.9, 0.999)),
    )

    if variant["vla_name"] == "VITRA_Paligemma":
        auto_wrap_policy, checkpointing_policy = get_fsdp_wrap_policy_and_checkpointing(variant["trainer"])
    else:
        raise NotImplementedError(f"Unsupported VLA name: {variant['vla_name']}")

    training_strategy.run_setup(
        run_dir=checkpoint_dir,
        n_train_examples=len(vla_dataset),
        auto_wrap_policy_modules=auto_wrap_policy,
        checkpointing_policy_modules=checkpointing_policy,
    )

    # Optimizer/scheduler state is restored after run_setup, under exactly the
    # same condition used for restoring the weights above.
    if restore_from_checkpoint:
        training_strategy.load_optimizer_and_scheduler(model_load_path)

    trackers = ["wandb"]
    overwatch.info(f"Creating Metrics with Active Trackers => `{trackers}`")
    # NOTE(review): `variant_str` still holds Path objects; `copied_variant` is
    # the JSON-safe copy -- confirm which one VLAMetrics expects as hparams.
    metrics = VLAMetrics(
        trackers,
        hparams=variant_str,
        run_id=run_id,
        run_dir=checkpoint_dir,
        wandb_project=variant["wandb_project"],
        wandb_entity=variant["wandb_entity"],
        resume_step=resume_step,
        resume_epoch=resume_epoch,
    )

    overwatch.info("Creating Dataloader", ctx_level=1)

    # Top-level values win; otherwise fall back to the dataset section.
    num_workers = variant["num_workers"] if variant["num_workers"] is not None else variant["train_dataset"]["num_workers"]
    prefetch_factor = variant["prefetch_factor"] if variant["prefetch_factor"] is not None else variant["train_dataset"]["prefetch_factor"]

    # DataLoader rejects a prefetch_factor when there are no workers.
    if num_workers == 0 or prefetch_factor == 0:
        prefetch_factor = None

    if overwatch.rank() == 0:
        print(f"num_workers: {num_workers}, prefetch_factor: {prefetch_factor}")

    # Position the sampler at the resume point; the second argument converts
    # optimizer steps into micro-batches via the gradient accumulation factor.
    batch_sampler.set_epoch(resume_epoch, resume_step * training_strategy.grad_accumulation_steps)

    setup_seed(variant["seed"], rank=torch.distributed.get_rank())

    dataloader = DataLoader(
        vla_dataset,
        batch_sampler=batch_sampler,
        collate_fn=collator,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        worker_init_fn=worker_init_fn,
        persistent_workers=num_workers > 0,
        pin_memory=num_workers > 0,
    )

    overwatch.info("Starting VLA Training Loop")
    training_strategy.run_training(
        dataloader,
        metrics,
        save_interval=variant["save_steps"],
        start_global_step=resume_step,
        start_epoch=resume_epoch,
    )

    overwatch.info("Done with Training =>> Finalizing Metrics")
    metrics.finalize()

    overwatch.info("... and that's all, folks!")
    dist.barrier()
    dist.destroy_process_group()
| |
|
def get_fsdp_wrap_policy_and_checkpointing(configs):
    """
    Build the FSDP wrapping and activation-checkpointing class sets for PaliGemma.

    Args:
        configs: Trainer configuration dictionary; only its ``strategy`` entry
            is read here.

    Returns:
        A ``(auto_wrap_policy, checkpointing_policy)`` tuple:
            - auto_wrap_policy: set of module classes to be wrapped individually
              by FSDP.
            - checkpointing_policy: ``{Gemma2DecoderLayer}`` when the strategy is
              ``"fsdp_paligemma_with_checkpointing"``, otherwise ``None``.

    Raises:
        NotImplementedError: If ``strategy`` is missing from ``configs`` or is
            ``"ddp"``.
    """
    unsupported_msg = "FSDP strategy not specified or DDP selected."
    if "strategy" not in configs:
        raise NotImplementedError(unsupported_msg)
    strategy = configs["strategy"]
    if strategy == "ddp":
        raise NotImplementedError(unsupported_msg)

    # Heavy model imports are deferred until a supported strategy is confirmed.
    from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer
    from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
    from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionTransformer

    from vitra.models.action_model import DiT
    from vitra.utils.nn_utils import MLPProjector

    # Module classes that FSDP should wrap as separate units.
    wrap_classes = {
        SiglipEncoderLayer,
        SiglipVisionTransformer,
        DiT,
        Gemma2DecoderLayer,
        PaliGemmaMultiModalProjector,
        MLPProjector,
    }

    # Activation checkpointing is opt-in through the strategy name.
    if strategy == "fsdp_paligemma_with_checkpointing":
        checkpoint_classes = {Gemma2DecoderLayer}
    else:
        checkpoint_classes = None

    return wrap_classes, checkpoint_classes
| |
|
def update_configs(configs, args):
    """
    Merge parsed command-line arguments into the loaded configuration.

    Command-line values take precedence over config-file values whenever they
    were actually supplied (i.e. are not ``None``). Nested dictionaries in
    ``args`` (e.g. ``trainer``) are merged key by key. ``configs`` is modified
    in place and also returned.

    Special cases:
        - ``use_bf16`` falls back to the config value, then to ``False``.
        - ``data_mix`` is written into ``configs["train_dataset"]`` (and, via the
          generic loop, also stored at the top level as before).
        - ``debug`` is a ``store_true`` flag whose default is ``False`` rather
          than ``None``; a ``False`` value therefore means "not given" and does
          not overwrite a ``debug`` entry already present in the config.
        - ``output_root`` / ``log_root`` / ``cache_root`` are converted to
          ``Path`` after the overrides are applied, so values coming from the
          command line are normalised too. ``cache_root`` additionally gets
          ``configs["model"]`` appended.

    Args:
        configs: Base configuration dictionary loaded from the config file.
        args: Parsed command-line arguments dictionary (see ``parse_args``).

    Returns:
        The updated ``configs`` dictionary.

    Raises:
        KeyError: If ``output_root`` or ``log_root`` is provided by neither the
            config file nor the command line, or if ``cache_root`` / ``model``
            is missing from the config.
    """
    if args["task_name"] is not None:
        configs["task_name"] = args["task_name"]

    configs["use_bf16"] = (
        args["use_bf16"]
        if args["use_bf16"] is not None
        else configs.get("use_bf16", False)
    )

    if args["data_mix"] is not None:
        configs["train_dataset"]["data_mix"] = args["data_mix"]

    for k, v in args.items():
        if k not in configs:
            print(f"{k} not in config. The value is {v}.")
            configs[k] = v
        elif isinstance(v, dict):
            for sub_k, sub_v in v.items():
                if sub_v is not None:
                    configs[k][sub_k] = sub_v
        elif k == "debug" and v is False:
            # `--debug` was not passed; keep whatever the config file says.
            continue
        elif v is not None:
            configs[k] = v

    # Normalise paths only after overrides so CLI-provided roots become Paths as well.
    for root_key in ("output_root", "log_root"):
        if configs.get(root_key) is None:
            raise KeyError(root_key)
        configs[root_key] = Path(configs[root_key])
    configs["cache_root"] = Path(configs["cache_root"]) / configs["model"]

    return configs
| |
|
def parse_args():
    """
    Parse command-line arguments for training configuration.

    Arguments are organized into two groups:
        1. Global arguments (experiment settings, paths, data configuration)
        2. Trainer arguments (training hyperparameters and strategy)

    The destination name of every argument is recorded at registration time, so
    the command line is parsed exactly once and ``--help`` lists both groups
    (parsing before all options are registered would exit early on ``--help``).

    Returns:
        Dictionary with structure:
            {
                'config': str,
                'seed': int,
                ...other global args...,
                'trainer': {
                    'strategy': str,
                    'gradient_clip_val': float,
                    ...other trainer args...
                }
            }
        Arguments that were not supplied are ``None`` (``debug`` is ``False``).
    """
    parser = argparse.ArgumentParser(description="VITRA VLA Training Script")

    global_names = set()
    trainer_names = set()

    def add_global(*flags, **kwargs):
        # Register a top-level option and remember its destination name.
        global_names.add(parser.add_argument(*flags, **kwargs).dest)

    add_global(
        "--config",
        type=str,
        help="Path to YAML/JSON configuration file for training"
    )
    add_global(
        "--seed",
        default=None,
        type=int,
        help="Random seed for reproducibility"
    )
    add_global(
        "--log_root",
        default=None,
        type=str,
        help="Root directory for logging"
    )
    add_global(
        "--output_root",
        default=None,
        type=str,
        help="Root directory for checkpoints and outputs"
    )
    add_global(
        "--model_load_path",
        default=None,
        type=str,
        help="Path to checkpoint for resuming training"
    )
    add_global(
        "--task_name",
        default=None,
        type=str,
        help="Unique identifier for this training run"
    )
    add_global(
        "--use_bf16",
        default=None,
        action="store_true",
        help="Enable bfloat16 mixed precision training"
    )
    add_global(
        "--data_mix",
        default=None,
        type=str,
        help="Dataset mixture configuration"
    )
    add_global(
        "--debug",
        default=False,
        action="store_true",
        help="Enable debug mode (freezes model parameters)"
    )
    add_global(
        "--fwd_pred_next_n",
        default=None,
        type=int,
        help="Number of future action steps to predict"
    )
    add_global(
        "--batch_size",
        default=None,
        type=int,
        help="Per-device batch size"
    )
    add_global(
        "--total_batch_size",
        default=None,
        type=int,
        help="Global batch size across all devices"
    )
    add_global(
        "--num_workers",
        default=None,
        type=int,
        help="Number of data loading workers per process"
    )
    add_global(
        "--prefetch_factor",
        default=None,
        type=int,
        help="Number of batches to prefetch per worker"
    )

    trainer_parser = parser.add_argument_group("trainer", "Training strategy and hyperparameters")

    def add_trainer(*flags, **kwargs):
        # Register a trainer option and remember its destination name.
        trainer_names.add(trainer_parser.add_argument(*flags, **kwargs).dest)

    add_trainer(
        "--strategy",
        default=None,
        type=str,
        help="Training strategy (e.g., 'fsdp')"
    )
    add_trainer(
        "--gradient_clip_val",
        default=None,
        type=float,
        help="Maximum gradient norm for clipping"
    )
    add_trainer(
        "--max_steps",
        default=None,
        type=int,
        help="Maximum number of training steps (overrides epochs)"
    )

    parsed = vars(parser.parse_args())

    # Split the flat namespace into top-level args and the nested trainer dict,
    # preserving the registration order of the options.
    args = {k: v for k, v in parsed.items() if k in global_names}
    args["trainer"] = {k: v for k, v in parsed.items() if k in trainer_names}

    return args
| |
|
| |
|
if __name__ == "__main__":
    # Ask the interpreter to dump Python tracebacks on fatal errors such as segfaults.
    faulthandler.enable()

    # Command-line values are layered on top of the config file named by --config.
    args = parse_args()
    configs = update_configs(load_config(args.get("config")), args)

    # Set up the default process group (NCCL backend) unless one already exists.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    experiment(variant=configs)