"""Tests for MegaBlocksMoeMLP (requires the megablocks package and a CUDA device)."""

from collections import namedtuple

import torch


def test_megablocks_moe_mlp_import():
    """Test that MegaBlocksMoeMLP can be imported."""
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def test_megablocks_moe_mlp_functionality():
    """Test the functionality of MegaBlocksMoeMLP."""
    from megablocks.layers import MegaBlocksMoeMLP

    model = MegaBlocksMoeMLP()
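
    # Lightweight stand-in for an experts module: a namedtuple class whose
    # attributes are overwritten below with the tensors the layer expects.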
    model.experts = namedtuple(
        "Experts",
        [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
        ],
    )

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size
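
    # Constant router weights make the expert assignment deterministic.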
    model.router = torch.nn.Linear(hs, ne).cuda()
    model.router.weight.data.fill_(1)
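
    # gate_up_proj's output width (isz) is twice down_proj's input width
    # (isz // 2 == 1536), consistent with a gated (GLU-style) MLP split.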
    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
    e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
    e.down_proj = torch.nn.Parameter(torch.ones(ne, isz // 2, hs, device="cuda"))
    e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
    e.hidden_size = hs

    x = torch.randn(1, 1, hs, device="cuda")
    output, expert_weights_out = model(x)
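    # The second return value is presumably the per-token expert routing
    # weights; this test only checks that the output shape is preserved.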

    assert output.shape == (1, 1, hs), "Output shape mismatch."