"""
Server starts a Trainer. Client sends data to the server to train.
"""

import os
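
# Set before the Megatron-related imports below: turn off Megatron's CUDA/process
# timers and keep NCCL logging at warning level.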
os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
os.environ["NCCL_DEBUG"] = "WARN"

import ray
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.models.gpt.gpt_model import ModelType
from omegaconf import OmegaConf
from tensordict import TensorDict
from torch import nn
from transformers import LlamaConfig

from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
from verl.utils.megatron_utils import get_model, mcore_model_parallel_config


@ray.remote
class Trainer(Worker):
    def __init__(self):
        super().__init__()

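        # Each worker process joins the NCCL process group on its own GPU and sets up
        # Megatron model parallelism: tensor parallel 2, pipeline parallel 1, context
        # parallel 1.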
        if not torch.distributed.is_initialized():
            rank = int(os.environ["LOCAL_RANK"])
            torch.distributed.init_process_group(backend="nccl")
            torch.cuda.set_device(rank)

            mpu.initialize_model_parallel(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=1,
                virtual_pipeline_model_parallel_size=None,
                pipeline_model_parallel_split_rank=None,
                use_sharp=False,
                context_parallel_size=1,
                expert_model_parallel_size=1,
                nccl_communicator_config_path=None,
            )
            tensor_parallel.model_parallel_cuda_manual_seed(10)

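        # Only one rank per data-parallel group returns results to the driver:
        # tensor-parallel rank 0 on the last pipeline stage with context-parallel
        # rank 0. Register this layout as the "train" mesh for dispatch/collect.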
        is_collect = (
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
            and mpu.get_context_parallel_rank() == 0
        )
        self._register_dispatch_collect_info(
            mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
        )

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
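        # A deliberately small LlamaConfig (tiny vocabulary) used only for this example.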
        actor_model_config = LlamaConfig(
            vocab_size=256,
            hidden_size=2048,
            intermediate_size=5504,
            num_hidden_layers=24,
            num_attention_heads=16,
            num_key_value_heads=16,
        )

        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
        self.megatron_config = megatron_config

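        # Provider invoked by get_model() below to build each model chunk on the
        # current GPU; get_model() then wraps the chunks with DDP (wrap_with_ddp=True).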
        def megatron_actor_model_provider(pre_process, post_process):
            parallel_model = ParallelLlamaForCausalLMRmPadPP(
                config=actor_model_config,
                megatron_config=megatron_config,
                pre_process=pre_process,
                post_process=post_process,
            )
            parallel_model.cuda()
            return parallel_model

        actor_module = get_model(
            model_provider_func=megatron_actor_model_provider,
            model_type=ModelType.encoder_or_decoder,
            wrap_with_ddp=True,
        )
        actor_module = nn.ModuleList(actor_module)

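        # A minimal optimizer config (learning rate and gradient clipping only); the
        # remaining Megatron optimizer settings are filled in by
        # init_megatron_optim_config.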
        optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
        optim_config = init_megatron_optim_config(optim_config)
        self.optimizer_config = optim_config
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

        self.model = actor_module[0]
        self.optimizer = actor_optimizer

    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_model(self, data: DataProto) -> DataProto:
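        # Dispatched over the "train" mesh registered in __init__: each data-parallel
        # rank receives its shard of the incoming DataProto, and only the collecting
        # ranks return their output.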
        input_ids = data.batch["input_ids"]
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]

        self.optimizer.zero_grad()
        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer))

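        # One training step per call: forward on the received batch, backward on the
        # mean of the logits as a dummy loss, then a Megatron optimizer step.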
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
        output.mean().backward()

        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
            self.megatron_config, self.megatron_config.timers
        )

        return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))


if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")

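    # Two detached worker processes on a single node. detached=True keeps the actors
    # alive after this script exits so a separate client can reconnect to them by name.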
    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix="trainer",
        detached=True,
    )

    worker_group.init_model()

    worker_names = worker_group.worker_names
    print(worker_names)
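
    # The printed worker names are what a client needs to locate these detached
    # actors. A hypothetical client-side batch (shapes illustrative, vocabulary
    # matching the vocab_size=256 model above) could be built as:
    #
    #   batch = TensorDict(
    #       {
    #           "input_ids": torch.randint(0, 256, (2, 512)),
    #           "attention_mask": torch.ones(2, 512, dtype=torch.long),
    #           "position_ids": torch.arange(512).unsqueeze(0).expand(2, -1),
    #       },
    #       batch_size=[2],
    #   )
    #   data = DataProto(batch=batch)
    #
    # and then passed to train_model. How the client reattaches to the detached
    # worker group is not shown here.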