| |
|
|
| #pragma once |
|
|
| #include "gemm/utils.h" |
|
|
| |
| |
| |
|
|
| namespace mlx { |
| namespace steel { |
|
|
| template <typename OutT, typename InT> |
| struct TransformNone { |
| static METAL_FUNC OutT apply(InT x) { |
| return static_cast<OutT>(x); |
| } |
|
|
| static METAL_FUNC OutT apply(InT x, OutT) { |
| return static_cast<OutT>(x); |
| } |
| }; |
|
|
| template <typename OutT, typename InT> |
| struct TransformAdd { |
| TransformAdd(const float, const float) {} |
|
|
| static METAL_FUNC OutT apply(InT x) { |
| return static_cast<OutT>(x); |
| } |
|
|
| static METAL_FUNC OutT apply(InT x, OutT c) { |
| return static_cast<OutT>(x) + c; |
| } |
| }; |
|
|
| template <typename OutT, typename InT> |
| struct TransformAxpby { |
| const float alpha; |
| const float beta; |
|
|
| TransformAxpby(const float alpha_, const float beta_) |
| : alpha(alpha_), beta(beta_) {} |
|
|
| static METAL_FUNC OutT apply(InT x) { |
| return static_cast<OutT>(x); |
| } |
|
|
| METAL_FUNC OutT apply(InT x, OutT c) const { |
| return static_cast<OutT>( |
| x * static_cast<InT>(alpha) + (static_cast<OutT>(beta) * c)); |
| } |
| }; |
|
|
// Type trait exposing an accumulator type for T; this primary template maps
// every T to float.
template <typename T>
struct AccumHelper {
  using accum_type = float;
};
|
|
| struct BlockSwizzle { |
| static METAL_FUNC int2 |
| swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { |
| const int tid_x = (tid.x) >> swizzle_log; |
| const int tid_y = |
| ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); |
| return int2(tid_x, tid_y); |
| } |
| }; |
|
|
| } |
| } |