import torch
import torch.nn as nn
import torch.nn.functional as F

# Use Apex's fused LayerNorm kernel when it is installed; otherwise fall back
# to the standard PyTorch implementation.
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm


class set_torch_seed(object):
    """Context manager that temporarily seeds the torch (and CUDA) RNGs.

    The previous RNG state is captured on construction and restored on exit,
    so seeding inside the block does not perturb the surrounding random stream.
    """

    def __init__(self, seed):
        assert isinstance(seed, int)
        # Save the current RNG state before re-seeding.
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.set_rng_state(self.rng_state)
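

# Usage sketch: draws inside the context manager are deterministic under the
# given seed, while draws outside it continue the original RNG stream. Kept
# as a comment so importing this module has no side effects.
#
#     before = torch.rand(1)        # advances the global RNG
#     with set_torch_seed(1234):
#         torch.rand(3)             # deterministic draw under seed 1234
#     after = torch.rand(1)         # resumes the stream where `before` left off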


def make_experts(args, embed_dim, expert_ffn_dim):
    world_size = (
        1
        if not torch.distributed.is_initialized()
        else torch.distributed.get_world_size()
    )
    expert_list = []
    ddp_rank = args.ddp_rank
    start_seed = torch.randint(1000000, (1,)).item()
    # At least as many experts as workers: each rank builds its local shard of
    # experts, and every expert gets a distinct seed so parameters differ
    # across experts and across ranks.
    if args.moe_expert_count >= world_size:
        assert (
            args.moe_expert_count % world_size == 0
        ), f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    FeedForwardNetwork(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                        args.layernorm_eps,
                        args.subln,
                    )
                )
    else:
        # Fewer experts than workers: several ranks share one expert, so ranks
        # mapping to the same expert use the same seed and build identical copies.
        assert (
            world_size % args.moe_expert_count == 0
        ), f"{world_size}, {args.moe_expert_count}"

        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
            expert_list.append(
                FeedForwardNetwork(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                    args.layernorm_eps,
                    args.subln,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts
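

# Usage sketch: make_experts expects an `args` namespace carrying the MoE and
# FFN hyper-parameters it reads above. The field values below are illustrative
# only, not defaults from this codebase. Kept as a comment so importing this
# module has no side effects.
#
#     from argparse import Namespace
#     args = Namespace(
#         moe_expert_count=4, ddp_rank=0, activation_fn="gelu",
#         dropout=0.1, activation_dropout=0.0, layernorm_eps=1e-5, subln=True,
#     )
#     experts = make_experts(args, embed_dim=512, expert_ffn_dim=2048)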


def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError(f"unsupported activation function: {activation}")


class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> (optional LayerNorm) -> fc2."""

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        # With sub-LayerNorm enabled, the hidden activations are normalized before fc2.
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        return x
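

# Minimal smoke test (sketch): builds a FeedForwardNetwork with illustrative
# hyper-parameters (not defaults from this codebase) and checks that a forward
# pass preserves the input shape. Runs only when the module is executed directly.
if __name__ == "__main__":
    ffn = FeedForwardNetwork(
        embed_dim=512,
        ffn_dim=2048,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    x = torch.randn(2, 16, 512)  # (batch, sequence, embed_dim)
    y = ffn(x)
    assert y.shape == x.shape
    print("FeedForwardNetwork output shape:", tuple(y.shape))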