| |
|
|
| #pragma once |
|
|
| #include <metal_stdlib> |
|
|
// Maps a flat element index to memory offsets in two operands that share
// `shape` but may each have their own strides.
//
// `elem` is decoded as a row-major (C-order) index into `shape`: the last
// dimension is peeled off first, so it is the fastest-varying one. Each
// recovered coordinate is scaled by the matching stride of each operand and
// the products are summed.
//
// Params:
//   elem      - flat element index to decode.
//   shape     - extent of each of the `ndim` dimensions.
//   a_strides - per-dimension strides of the first operand.
//   b_strides - per-dimension strides of the second operand.
//   ndim      - number of dimensions described by the three arrays.
// Returns:
//   ulong2(offset into first operand, offset into second operand).
//
// NOTE(review): presumably broadcast dimensions are expressed by a zero
// stride in the corresponding operand; confirm against the callers.
METAL_FUNC ulong2 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    int ndim) {
  ulong offset_a = 0;
  ulong offset_b = 0;

  int dim = ndim - 1;
  // Once the remaining flat index reaches zero, every outer coordinate is
  // zero and would contribute nothing, so the walk can stop early.
  while (dim >= 0 && elem != 0) {
    // The extent is converted to unsigned, matching the uint/int promotion
    // the arithmetic below relies on.
    const uint extent = uint(shape[dim]);
    const int coord = elem % extent;
    elem = elem / extent;

    // Strides are signed, but the sum is kept in an unsigned 64-bit
    // accumulator. Negative partial sums wrap modulo 2^64, so the final
    // value is correct whenever the true offset is non-negative.
    offset_a += coord * a_strides[dim];
    offset_b += coord * b_strides[dim];
    --dim;
  }

  return ulong2(offset_a, offset_b);
}
|
|
// Maps a flat element index to memory offsets in three operands that share
// `shape` but may each have their own strides.
//
// `elem` is decoded as a row-major (C-order) index into `shape`: the last
// dimension is handled first, so it is the fastest-varying one. Each
// recovered coordinate is multiplied by the matching stride of every operand
// and accumulated.
//
// Params:
//   elem      - flat element index to decode.
//   shape     - extent of each of the `ndim` dimensions.
//   a_strides - per-dimension strides of the first operand.
//   b_strides - per-dimension strides of the second operand.
//   c_strides - per-dimension strides of the third operand.
//   ndim      - number of dimensions described by the four arrays.
// Returns:
//   ulong3(offset into first, offset into second, offset into third operand).
//
// NOTE(review): presumably broadcast dimensions are expressed by a zero
// stride in the corresponding operand; confirm against the callers.
METAL_FUNC ulong3 elem_to_loc_broadcast(
    uint elem,
    constant const int* shape,
    constant const int64_t* a_strides,
    constant const int64_t* b_strides,
    constant const int64_t* c_strides,
    int ndim) {
  ulong off_a = 0;
  ulong off_b = 0;
  ulong off_c = 0;

  for (int axis = ndim - 1; axis >= 0; --axis) {
    // Nothing left to decode: every outer coordinate is zero.
    if (elem == 0) {
      break;
    }

    const uint extent = uint(shape[axis]);
    const int coord = elem % extent;
    elem /= extent;

    // Signed strides are accumulated into unsigned 64-bit offsets. Any
    // wraparound is modulo 2^64 and cancels out when the true offset is
    // non-negative.
    off_a += coord * a_strides[axis];
    off_b += coord * b_strides[axis];
    off_c += coord * c_strides[axis];
  }

  return ulong3(off_a, off_b, off_c);
}