| |
|
|
| #pragma once |
|
|
| #include "gemm/defines.h" |
|
|
| |
| |
| |
|
|
| namespace mlx { |
| namespace steel { |
|
|
| template < |
| typename T, |
| short BROWS, |
| short BCOLS, |
| short dst_ld, |
| short reduction_dim, |
| short tgp_size, |
| short alignment = 1, |
| short n_reads = (BCOLS * BROWS) / (tgp_size), |
| short TCOLS = BCOLS / n_reads, |
| short TROWS = tgp_size / TCOLS> |
| struct BlockLoader { |
| STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; |
| STEEL_CONST short vec_size = n_reads; |
|
|
| |
| const int src_ld; |
| const int tile_stride; |
|
|
| |
| const short thread_idx; |
| const short bi; |
| const short bj; |
|
|
| |
| threadgroup T* dst; |
| const device T* src; |
|
|
| struct alignas(alignment * sizeof(T)) ReadVector { |
| uint8_t v[sizeof(T) * vec_size]; |
| }; |
|
|
| |
| METAL_FUNC BlockLoader( |
| const device T* src_, |
| const int src_ld_, |
| threadgroup T* dst_, |
| ushort simd_group_id [[simdgroup_index_in_threadgroup]], |
| ushort simd_lane_id [[thread_index_in_simdgroup]]) |
| : src_ld(src_ld_), |
| tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), |
| thread_idx(simd_group_id * 32 + simd_lane_id), |
| bi(thread_idx / TCOLS), |
| bj(vec_size * (thread_idx % TCOLS)), |
| dst(dst_ + bi * dst_ld + bj), |
| src(src_ + bi * src_ld + bj) {} |
|
|
| |
| template <typename UnaryOp> |
| METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { |
| STEEL_PRAGMA_UNROLL |
| for (short i = 0; i < BROWS; i += TROWS) { |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); |
| } |
| } |
| } |
|
|
| |
| METAL_FUNC void load_unsafe() const { |
| STEEL_PRAGMA_UNROLL |
| for (short i = 0; i < BROWS; i += TROWS) { |
| *((threadgroup ReadVector*)(&dst[i * dst_ld])) = |
| *((const device ReadVector*)(&src[i * src_ld])); |
| } |
| } |
|
|
| |
| METAL_FUNC void load_safe(short2 src_tile_dim) const { |
| src_tile_dim = src_tile_dim - short2(bj, bi); |
|
|
| |
| if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { |
| STEEL_PRAGMA_UNROLL |
| for (short i = 0; i < BROWS; i += TROWS) { |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| dst[i * dst_ld + j] = T(0); |
| } |
| } |
| return; |
| } |
|
|
| |
| bool tmp_idx[vec_size]; |
| T tmp_val[vec_size]; |
|
|
| STEEL_PRAGMA_UNROLL |
| for (short i = 0; i < BROWS; i += TROWS) { |
| |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); |
| } |
|
|
| |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; |
| } |
|
|
| |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); |
| } |
|
|
| |
| STEEL_PRAGMA_UNROLL |
| for (short j = 0; j < vec_size; j++) { |
| dst[i * dst_ld + j] = tmp_val[j]; |
| } |
| } |
| } |
|
|
| |
| METAL_FUNC void next() { |
| src += tile_stride; |
| } |
| }; |
|
|
| } |
| } |
|
|