import torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the fused (faster) LayerNorm from NVIDIA Apex when it is installed;
# otherwise fall back to the standard PyTorch implementation.
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm

class set_torch_seed(object):
    """Context manager that temporarily seeds the CPU (and CUDA) RNG and
    restores the previous RNG state on exit."""

    def __init__(self, seed):
        assert isinstance(seed, int)
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)
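
# Usage sketch (illustrative only, not part of the module API): seed a block of
# parameter initialization deterministically while leaving the surrounding RNG
# stream untouched, e.g.
#
#     with set_torch_seed(1234):
#         layer = nn.Linear(512, 512)   # hypothetical dimensions
#
# On leaving the `with` block the previous torch / CUDA RNG state is restored.
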
def make_experts(args, embed_dim, expert_ffn_dim):
    world_size = (
        1
        if not torch.distributed.is_initialized()
        else torch.distributed.get_world_size()
    )
    expert_list = []
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()
    # At least as many experts as workers: each rank builds its own shard of experts.
    if args.moe_expert_count >= world_size:
        assert (
            args.moe_expert_count % world_size == 0
        ), f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            # Give every expert its own seed so the experts do not start from
            # identical weights.
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    FeedForwardNetwork(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                        args.layernorm_eps,
                        args.subln,
                    )
                )
    else:
        # Fewer experts than workers: each expert is replicated across several ranks;
        # ranks that share an expert use the same seed offset so the replicas
        # initialize identically.
        assert (
            world_size % args.moe_expert_count == 0
        ), f"{world_size}, {args.moe_expert_count}"

        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
            expert_list.append(
                FeedForwardNetwork(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                    args.layernorm_eps,
                    args.subln,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts
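
# Example layout (illustrative numbers): with moe_expert_count=8 and world_size=4,
# each rank builds 2 local experts; rank 0 seeds them with start_seed+0 and
# start_seed+1, rank 1 with start_seed+2 and start_seed+3, and so on. With
# moe_expert_count=2 and world_size=4, ranks 0 and 2 share one seed offset and
# ranks 1 and 3 share the other.
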

def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError(f"unsupported activation function: {activation}")

class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> (optional sub-LayerNorm) -> fc2."""

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        # When sub-LayerNorm is enabled, normalize the hidden activations before fc2.
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        return x
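

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The dimensions and hyper-parameters below
    # are illustrative placeholders, not values prescribed by this module, and the
    # make_experts call assumes a PyTorch build where torch.distributed is available
    # (the same assumption make_experts itself makes).
    from types import SimpleNamespace

    ffn = FeedForwardNetwork(
        embed_dim=16,
        ffn_dim=64,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    out = ffn(torch.randn(2, 8, 16))  # (batch, seq_len, embed_dim) -> same shape
    print(out.shape)

    # make_experts expects an args object carrying the FFN fields used above plus
    # the MoE settings; SimpleNamespace stands in for the real training config here.
    args = SimpleNamespace(
        ddp_rank=0,
        moe_expert_count=2,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    experts = make_experts(args, embed_dim=16, expert_ffn_dim=64)
    print(len(experts))  # 2 local experts on a single (non-distributed) worker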