# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Server starts a Trainer. Client sends data to the server to train.
"""
import os
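# Configure Megatron/NCCL environment flags up front, before the framework imports below.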
os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
os.environ["NCCL_DEBUG"] = "WARN"
import ray
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.models.gpt.gpt_model import ModelType
from omegaconf import OmegaConf
from tensordict import TensorDict
from torch import nn
from transformers import LlamaConfig
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.megatron.optimizer import get_megatron_optimizer, init_megatron_optim_config
from verl.utils.megatron_utils import get_model, mcore_model_parallel_config


@ray.remote
class Trainer(Worker):
    def __init__(self):
        super().__init__()
        if not torch.distributed.is_initialized():
            # Initialize torch.distributed and Megatron model parallelism (TP=2, PP=1)
            # once per worker process.
            rank = int(os.environ["LOCAL_RANK"])
            torch.distributed.init_process_group(backend="nccl")
            torch.cuda.set_device(rank)
            mpu.initialize_model_parallel(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=1,
                virtual_pipeline_model_parallel_size=None,
                pipeline_model_parallel_split_rank=None,
                use_sharp=False,
                context_parallel_size=1,
                expert_model_parallel_size=1,
                nccl_communicator_config_path=None,
            )
            tensor_parallel.model_parallel_cuda_manual_seed(10)

        # Only the last pipeline stage on tensor-parallel rank 0 and context-parallel
        # rank 0 returns its output when results are collected.
        is_collect = (
            mpu.get_tensor_model_parallel_rank() == 0
            and mpu.get_pipeline_model_parallel_rank() == mpu.get_pipeline_model_parallel_world_size() - 1
            and mpu.get_context_parallel_rank() == 0
        )
        self._register_dispatch_collect_info(
            mesh_name="train", dp_rank=mpu.get_data_parallel_rank(), is_collect=is_collect
        )
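
    # Runs once on every worker (ONE_TO_ALL dispatch): builds the tensor/pipeline-parallel
    # Llama model and its Megatron optimizer.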
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        actor_model_config = LlamaConfig(
            vocab_size=256,
            hidden_size=2048,
            intermediate_size=5504,
            num_hidden_layers=24,
            num_attention_heads=16,
            num_key_value_heads=16,
        )
        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
        self.megatron_config = megatron_config

        def megatron_actor_model_provider(pre_process, post_process):
            # vpp is not supported yet because it hangs for an unknown reason; needs debugging.
            # this_megatron_config = copy.deepcopy(megatron_config)
            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
            parallel_model = ParallelLlamaForCausalLMRmPadPP(
                config=actor_model_config,
                megatron_config=megatron_config,
                pre_process=pre_process,
                post_process=post_process,
            )
            parallel_model.cuda()
            return parallel_model

        actor_module = get_model(
            model_provider_func=megatron_actor_model_provider,
            model_type=ModelType.encoder_or_decoder,
            wrap_with_ddp=True,
        )
        actor_module = nn.ModuleList(actor_module)

        optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
        optim_config = init_megatron_optim_config(optim_config)
        self.optimizer_config = optim_config
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

        self.model = actor_module[0]
        self.optimizer = actor_optimizer
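
    # Dispatched over the "train" mesh registered in __init__: each worker trains on its
    # data-parallel shard of the batch for one step and returns its detached output.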
    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"))
    def train_model(self, data: DataProto) -> DataProto:
        input_ids = data.batch["input_ids"]
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]

        self.optimizer.zero_grad()
        # Assumes use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm.
        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer))

        # Update for one iteration: forward, backward, then optimizer step.
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
        output.mean().backward()
        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
            self.megatron_config, self.megatron_config.timers
        )

        return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
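

# Driver entry point: start a detached Ray resource pool and worker group, initialize
# the model on every worker, and print the worker names so the detached workers can be
# found later.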
if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")

    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix="trainer",
        detached=True,
    )

    worker_group.init_model()
    worker_names = worker_group.worker_names
    print(worker_names)
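
# A separate client process could attach to the same Ray namespace, reuse the detached
# workers, and drive training. A minimal sketch (hypothetical batch shapes; how the
# client recovers `worker_group` from the detached workers is left open, this script
# only prints the worker names):
#
#     ray.init(address="auto", namespace="verl")
#     # ... re-obtain `worker_group` from the detached workers ...
#     batch = TensorDict(
#         {
#             "input_ids": torch.randint(0, 256, (4, 128)),
#             "attention_mask": torch.ones(4, 128, dtype=torch.long),
#             "position_ids": torch.arange(128).unsqueeze(0).expand(4, -1),
#         },
#         batch_size=[4],
#     )
#     result = worker_group.train_model(DataProto(batch=batch))
#     print(result.batch["loss"].shape)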