File size: 3,585 Bytes
1f0c850 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | import torch
from kernels.benchmark import Benchmark
def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return torch.clamp(x.float() / scale.float(), -448.0, 448.0).to(torch.float8_e4m3fn)
def _dequant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
return x.float() * scale.float()
def _compiler_disable(fn):
compiler = getattr(torch, "compiler", None)
if compiler is not None and hasattr(compiler, "disable"):
return compiler.disable(fn)
return torch._dynamo.disable(fn)
def _gelu_quantize_fp8_boundary(
    hidden: torch.Tensor, bias: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Add ``bias`` to ``hidden``, apply tanh-approximated GELU, quantize to FP8.

    The bias add and the activation are computed in float32; the activated
    values are then passed to ``_quantize_fp8`` together with ``scale``.
    """
    pre_activation = hidden.to(torch.float32) + bias.to(torch.float32)
    activated = torch.nn.functional.gelu(pre_activation, approximate="tanh")
    return _quantize_fp8(activated, scale)
def _bf16_bias_add_boundary(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
return (out.float() + bias.float()).to(torch.bfloat16)
# Variants of the two boundary ops wrapped by `_compiler_disable`; these wrapped
# versions are the ones `FP8GeluMlpBenchmark._reference` calls.
# NOTE(review): presumably wrapped so the reference numerics at these dtype
# boundaries are not altered by compilation/fusion — confirm intent.
_stable_gelu_quantize_fp8 = _compiler_disable(_gelu_quantize_fp8_boundary)
_stable_bf16_bias_add = _compiler_disable(_bf16_bias_add_boundary)
class FP8GeluMlpBenchmark(Benchmark):
    """Benchmark case for the ``fp8_gelu_mlp_bf16`` kernel.

    Builds FP8-quantized activations and weights with bfloat16 biases, calls
    the kernel with preallocated output buffers, and provides a reference
    result computed from dequantized float32 matmuls.

    NOTE(review): the ``setup_*`` / ``benchmark_*`` / ``verify_*`` naming is
    presumably how the ``Benchmark`` base class discovers and pairs these
    methods — confirm against ``kernels.benchmark``.
    """

    # NOTE(review): presumably read by the Benchmark base class to seed the
    # random number generator before setup — confirm.
    seed = 17

    def _setup_shape(self, M: int, K: int, H: int, N: int) -> None:
        """Create scales, FP8 operands, biases and output buffers.

        Tensors created: ``x`` is (M, K), ``up_w`` is (H, K), ``down_w`` is
        (N, H), ``up_b`` is (H,), ``down_b`` is (N,); the ``hidden`` and
        ``hidden_fp8`` buffers are (M, H) and ``out`` is (M, N).
        """
        self.M, self.K, self.H, self.N = M, K, H, N
        device = self.device

        def scalar_scale(value):
            # One-element float32 scale tensor on the benchmark device.
            return torch.tensor([value], device=device, dtype=torch.float32)

        def random_bf16(*dims):
            # Standard-normal bfloat16 tensor of the given shape.
            return torch.randn(dims, device=device, dtype=torch.bfloat16)

        def empty_buffer(dims, dtype):
            # Uninitialized buffer handed to the kernel as an output argument.
            return torch.empty(dims, device=device, dtype=dtype)

        self.x_scale = scalar_scale(0.05)
        self.up_scale = scalar_scale(0.04)
        self.hidden_scale = scalar_scale(0.25)
        self.down_scale = scalar_scale(0.04)

        # The random draws stay in the order x, up_w, down_w, up_b, down_b so
        # the generated data for a given RNG state does not change.
        self.x = _quantize_fp8(random_bf16(M, K), self.x_scale)
        self.up_w = _quantize_fp8(random_bf16(H, K), self.up_scale)
        self.down_w = _quantize_fp8(random_bf16(N, H), self.down_scale)
        self.up_b = random_bf16(H)
        self.down_b = random_bf16(N)

        self.hidden = empty_buffer((M, H), torch.bfloat16)
        self.hidden_fp8 = empty_buffer((M, H), torch.float8_e4m3fn)
        self.out = empty_buffer((M, N), torch.bfloat16)

    def _reference(self) -> torch.Tensor:
        """Compute the expected output from the dequantized operands.

        Steps: up projection in float32 cast to bfloat16, bias + tanh-GELU +
        FP8 quantization with ``hidden_scale``, down projection in float32
        cast to bfloat16, then the bfloat16 bias add.
        """
        x_f32 = _dequant_fp8(self.x, self.x_scale)
        up_w_f32 = _dequant_fp8(self.up_w, self.up_scale)
        up_projection = torch.matmul(x_f32, up_w_f32.T).to(torch.bfloat16)

        activated_fp8 = _stable_gelu_quantize_fp8(
            up_projection, self.up_b, self.hidden_scale
        )

        activated_f32 = _dequant_fp8(activated_fp8, self.hidden_scale)
        down_w_f32 = _dequant_fp8(self.down_w, self.down_scale)
        down_projection = torch.matmul(activated_f32, down_w_f32.T).to(
            torch.bfloat16
        )

        return _stable_bf16_bias_add(down_projection, self.down_b)

    def setup_smoke_mlp(self) -> None:
        """Prepare the smoke case with M=16, K=128, H=256, N=128."""
        self._setup_shape(M=16, K=128, H=256, N=128)

    def benchmark_smoke_mlp(self) -> None:
        """Call ``self.kernel.fp8_gelu_mlp_bf16`` on the prepared tensors."""
        run_kernel = self.kernel.fp8_gelu_mlp_bf16
        run_kernel(
            self.x,
            self.up_w,
            self.up_b,
            self.down_w,
            self.down_b,
            self.x_scale,
            self.up_scale,
            self.hidden_scale,
            self.down_scale,
            hidden_bf16=self.hidden,
            hidden_fp8=self.hidden_fp8,
            out=self.out,
        )

    def verify_smoke_mlp(self) -> torch.Tensor:
        """Return the reference output for the smoke case.

        NOTE(review): presumably the harness compares this against
        ``self.out`` — confirm against the ``Benchmark`` base class.
        """
        return self._reference()
|