# Kernels — optimizer/test/conftest.py
# Commit e74d98f (wyldecat):
#   "Add torch.compile, CUDA graph, and compiled momentum [skip-build]"
import logging
import pytest
import torch
import torch.distributed as dist
from packaging import version
from transformers import AutoModelForCausalLM
# Module-level logger for the test suite; basicConfig so fixture progress
# messages are visible when running under torchrun/pytest.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Raise dynamo recompile limit so that compiled momentum (batch_pre_ortho)
# does not fall back to eager mode when the test suite runs 30+ model
# configurations with different tensor shapes in a single process.
# NOTE(review): torch._dynamo.config is a private API — confirm the
# `recompile_limit` attribute still exists on the pinned torch version.
torch._dynamo.config.recompile_limit = 64
# Fixed RNG seed so model weights / random gradients are reproducible.
SEED = 0xdeadbeef
def pytest_addoption(parser):
    """Register the command-line flags used by the optimizer test suite.

    All three flags are boolean switches that default to off.
    """
    skip_verify_help = (
        "Skip verification of optimizer step correctness with sequential implementation.\n"
        "This can be useful when GPU memory is limited.")
    option_specs = (
        ("--measure-perf",
         "Measure execution time and peak memory usage during optimizer step."),
        ("--do-profile", "Enable profiling during tests."),
        ("--skip-verify", skip_verify_help),
    )
    for flag, help_text in option_specs:
        parser.addoption(
            flag,
            action="store_true",
            default=False,
            help=help_text,
        )
def pytest_configure(config):
    """Reject invalid flag combinations after option parsing.

    Raises:
        pytest.UsageError: if --do-profile is given without --measure-perf.
    """
    profiling_enabled = config.getoption("--do-profile")
    perf_enabled = config.getoption("--measure-perf")
    if profiling_enabled and not perf_enabled:
        raise pytest.UsageError(
            "--do-profile requires --measure-perf. Please enable both flags.")
@pytest.fixture(scope="session")
def measure_perf(request):
    """Whether --measure-perf was passed on the command line."""
    flag_value = request.config.getoption("--measure-perf")
    return flag_value
@pytest.fixture(scope="session")
def do_profile(request):
    """Whether --do-profile was passed on the command line."""
    flag_value = request.config.getoption("--do-profile")
    return flag_value
@pytest.fixture(scope="session")
def skip_verify(request):
    """Whether --skip-verify was passed on the command line."""
    flag_value = request.config.getoption("--skip-verify")
    return flag_value
@pytest.fixture(scope="session", autouse=True)
def init_dist(request):
    """Initialize torch.distributed (NCCL backend) for the whole session.

    Skips the entire session when torch is too old, when initialization
    fails, or when the world size is not exactly 8.  Destroys the process
    group after the last test.
    """
    if version.parse(torch.__version__) < version.parse("2.8"):
        # pytest.skip raises Skipped, so no explicit return is needed here.
        pytest.skip("torch>=2.8.0 is required for parallel muon")
    try:
        dist.init_process_group(backend="nccl")
        # Map each rank to a GPU round-robin across the visible devices.
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    except Exception as e:
        # Use the module logger rather than print so the failure reason shows
        # up with the rest of the suite's log output.
        logger.error("Failed to initialize torch.distributed: %s", e)
        pytest.skip("Failed to initialize torch.distributed")
    if dist.get_world_size() != 8:
        # skip() raises before `yield`, so the teardown below would never run;
        # tear the process group down explicitly before skipping.
        dist.destroy_process_group()
        # NOTE(review): the message names test_rms_norm_sequence_parallel.py,
        # which looks copied from another suite — confirm the intended file.
        pytest.skip("Need 8 processes in dist group. "
                    "You can run with `torchrun --nproc-per-node=8 "
                    "--local-ranks-filter 0 -m pytest "
                    "test_rms_norm_sequence_parallel.py`. "
                    "To run with less than 8 gpus, modify "
                    "the test cases accordingly.")
    yield
    dist.destroy_process_group()
@pytest.fixture(scope="session")
def inputs():
    """Load the Motif-2.6B model and generate random gradients for testing.

    Returns:
        list: ``[model, grads, qk_logits]`` where
            - model (torch.nn.Module): the Motif-2.6B model,
            - grads (list[torch.Tensor]): one random gradient per parameter,
            - qk_logits (dict[int, torch.Tensor]): random QK logits keyed by
              layer index.
    """
    model_name = "Motif-Technologies/Motif-2.6B-4layer-random"
    # Seed both CPU and CUDA RNGs so the random gradients are reproducible.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    logger.info(
        f"Loaded model {model_name}. ({len(list(model.parameters()))} parameters)"
    )
    grads: list[torch.Tensor] = [
        torch.randn_like(p, device=p.device, dtype=p.dtype)
        for p in model.parameters()
    ]
    qk_logits: dict[int, torch.Tensor] = {}
    for layer_idx in range(model.config.num_hidden_layers):
        qk_logits[layer_idx] = torch.randn(
            model.config.num_attention_heads,
            device=model.device,
            dtype=torch.bfloat16,
        )
    return [model, grads, qk_logits]
def _create_moe_model(num_experts=8, top_k=2, n_layers=4):
    """Create a torchtitan Llama4 MoE model with random gradients.

    Args:
        num_experts: number of routed experts per MoE layer.
        top_k: number of experts each token is routed to.
        n_layers: number of transformer layers.

    Returns:
        list: ``[model, grads]`` — the model and one random gradient per
        parameter.
    """
    from torchtitan.models.llama4.model.args import TransformerModelArgs
    from torchtitan.models.llama4.model.model import Transformer
    from torchtitan.models.moe import MoEArgs

    # Seed both CPU and CUDA RNGs so weight init / gradients are reproducible.
    torch.manual_seed(SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(SEED)

    model_args = TransformerModelArgs(
        dim=2048,
        n_layers=n_layers,
        n_heads=16,
        n_kv_heads=8,
        vocab_size=32000,
        norm_eps=1e-5,
        rope_theta=10000,
        max_seq_len=4096,
        moe_args=MoEArgs(
            num_experts=num_experts,
            num_shared_experts=1,
            top_k=top_k,
            score_func="sigmoid",
        ),
        # Every layer is an MoE layer.
        interleave_moe_layer_step=1,
    )
    model = Transformer(model_args)
    model.init_weights()
    logger.info(f"Created torchtitan Llama4 MoE model "
                f"(num_experts={num_experts}, n_layers={n_layers}, "
                f"{len(list(model.parameters()))} parameters)")

    grads = []
    for param in model.parameters():
        grads.append(
            torch.randn_like(param, device=param.device, dtype=param.dtype))
    return [model, grads]
@pytest.fixture(scope="session")
def moe_inputs():
    """MoE model with 8 experts (standard config)."""
    model_and_grads = _create_moe_model(num_experts=8, top_k=2)
    return model_and_grads
@pytest.fixture(scope="session")
def moe_inputs_few_experts():
    """MoE model with 2 experts (triggers EFSDP Shard(1) mode)."""
    model_and_grads = _create_moe_model(num_experts=2, top_k=1)
    return model_and_grads