import torch

from collections import namedtuple


def test_megablocks_moe_mlp_import():
    """Test if MegaBlocksMoeMLP can be imported."""
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def test_megablocks_moe_mlp_functionality():
| | """Test the functionality of MegaBlocksMoeMLP.""" |
| | from megablocks.layers import MegaBlocksMoeMLP |
| |
|
| | |
| | model = MegaBlocksMoeMLP() |
| |
|
| | |
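
    # Use a namedtuple class as a simple container for the expert weights;
    # the projection tensors below are attached to it as class attributes.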
    model.experts = namedtuple(
        "Experts",
        [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
        ],
    )

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size
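
    # Constant router weights: every expert receives the same routing score.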
    model.router = torch.nn.Linear(hs, ne).cuda()
    model.router.weight.data.fill_(1)
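
    # Expert weights (all ones, zero biases, to keep the numerics trivial):
    # gate_up_proj maps hidden (1152) -> intermediate (3072) per expert, and
    # down_proj maps 1536 features (intermediate_size // 2, presumably the
    # gated half of the gate_up output) back to the hidden size.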
    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
    e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
    e.down_proj = torch.nn.Parameter(torch.ones(ne, 1536, hs, device="cuda"))
    e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
    e.hidden_size = hs

    # Single-token input of shape (1, 1, hidden_size).
    x = torch.randn(1, 1, hs, device="cuda")
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, hs), "Output shape mismatch."