File size: 3,585 Bytes
1f0c850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch

from kernels.benchmark import Benchmark


def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.clamp(x.float() / scale.float(), -448.0, 448.0).to(torch.float8_e4m3fn)


def _dequant_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x.float() * scale.float()


def _compiler_disable(fn):
    compiler = getattr(torch, "compiler", None)
    if compiler is not None and hasattr(compiler, "disable"):
        return compiler.disable(fn)
    return torch._dynamo.disable(fn)


def _gelu_quantize_fp8_boundary(
    hidden: torch.Tensor, bias: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
    """Add ``bias``, apply tanh-approximate GELU in float32, then quantize.

    Args:
        hidden: Pre-activation tensor; upcast to float32.
        bias: Bias broadcast-added to ``hidden`` (also upcast to float32).
        scale: Quantization scale forwarded to ``_quantize_fp8``.

    Returns:
        The activated tensor as produced by ``_quantize_fp8``.
    """
    biased = hidden.float() + bias.float()
    activated = torch.nn.functional.gelu(biased, approximate="tanh")
    return _quantize_fp8(activated, scale)


def _bf16_bias_add_boundary(out: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return (out.float() + bias.float()).to(torch.bfloat16)


# Versions of the boundary helpers wrapped by _compiler_disable so torch.compile
# does not trace into them (presumably to keep reference numerics independent of
# compiler fusion -- TODO confirm intent).
_stable_gelu_quantize_fp8, _stable_bf16_bias_add = (
    _compiler_disable(_gelu_quantize_fp8_boundary),
    _compiler_disable(_bf16_bias_add_boundary),
)


class FP8GeluMlpBenchmark(Benchmark):
    """Benchmark for ``kernel.fp8_gelu_mlp_bf16``, an FP8 two-layer GELU MLP.

    The expected result, as computed by ``_reference``, is::

        hidden     = bf16(dequant(x) @ dequant(up_w).T)
        hidden_fp8 = quantize(gelu_tanh(hidden + up_b), hidden_scale)
        out        = bf16(bf16(dequant(hidden_fp8) @ dequant(down_w).T) + down_b)

    ``self.device`` and ``self.kernel`` are read but not defined in this class.
    They are presumably provided by the ``Benchmark`` base class, which
    presumably also consumes ``seed`` (TODO confirm).
    """

    seed = 17

    def _setup_shape(self, M: int, K: int, H: int, N: int) -> None:
        """Create FP8 inputs, bf16 biases, scales and output buffers.

        Shapes: x (M, K), up_w (H, K), down_w (N, H), up_b (H,), down_b (N,).
        Preallocated buffers: hidden (M, H) bf16, hidden_fp8 (M, H) FP8,
        out (M, N) bf16.
        """
        dev = self.device
        self.M, self.K, self.H, self.N = M, K, H, N

        def single_scale(value: float) -> torch.Tensor:
            # One-element float32 scale tensor.
            return torch.tensor([value], device=dev, dtype=torch.float32)

        def normal_bf16(shape: tuple) -> torch.Tensor:
            return torch.randn(shape, device=dev, dtype=torch.bfloat16)

        self.x_scale = single_scale(0.05)
        self.up_scale = single_scale(0.04)
        self.hidden_scale = single_scale(0.25)
        self.down_scale = single_scale(0.04)

        # Keep the RNG draw order (x, up_w, down_w, up_b, down_b) fixed so
        # seeded runs produce the same data.
        self.x = _quantize_fp8(normal_bf16((M, K)), self.x_scale)
        self.up_w = _quantize_fp8(normal_bf16((H, K)), self.up_scale)
        self.down_w = _quantize_fp8(normal_bf16((N, H)), self.down_scale)
        self.up_b = normal_bf16((H,))
        self.down_b = normal_bf16((N,))

        # Uninitialized buffers handed to the kernel for its outputs.
        self.hidden = torch.empty((M, H), device=dev, dtype=torch.bfloat16)
        self.hidden_fp8 = torch.empty((M, H), device=dev, dtype=torch.float8_e4m3fn)
        self.out = torch.empty((M, N), device=dev, dtype=torch.bfloat16)

    def _reference(self) -> torch.Tensor:
        """Compute the expected MLP output with PyTorch ops (bf16 result).

        The bias+GELU+quantize step and the final bias add go through the
        ``_compiler_disable``-wrapped helpers.
        """
        x_deq = _dequant_fp8(self.x, self.x_scale)
        up_deq = _dequant_fp8(self.up_w, self.up_scale)
        hidden = torch.matmul(x_deq, up_deq.T).to(torch.bfloat16)

        hidden_q = _stable_gelu_quantize_fp8(hidden, self.up_b, self.hidden_scale)

        hidden_deq = _dequant_fp8(hidden_q, self.hidden_scale)
        down_deq = _dequant_fp8(self.down_w, self.down_scale)
        projected = torch.matmul(hidden_deq, down_deq.T).to(torch.bfloat16)
        return _stable_bf16_bias_add(projected, self.down_b)

    def setup_smoke_mlp(self) -> None:
        """Small smoke-test shape: M=16, K=128, H=256, N=128."""
        self._setup_shape(M=16, K=128, H=256, N=128)

    def benchmark_smoke_mlp(self) -> None:
        """Run the kernel once, writing into the preallocated buffers."""
        # Direct call on purpose: keep the timed path free of extra host work.
        self.kernel.fp8_gelu_mlp_bf16(
            self.x,
            self.up_w,
            self.up_b,
            self.down_w,
            self.down_b,
            self.x_scale,
            self.up_scale,
            self.hidden_scale,
            self.down_scale,
            hidden_bf16=self.hidden,
            hidden_fp8=self.hidden_fp8,
            out=self.out,
        )

    def verify_smoke_mlp(self) -> torch.Tensor:
        """Return the reference output for the smoke shape.

        Presumably the harness compares this against ``self.out`` (TODO confirm).
        """
        return self._reference()