danieldk HF Staff commited on
Commit
96dea35
·
verified ·
1 Parent(s): e303cc7

Build uploaded using `kernels` (batch 9/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h +232 -0
  3. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h +264 -0
  4. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h +374 -0
  5. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h +362 -0
  6. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h +267 -0
  7. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h +248 -0
  8. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h +606 -0
  9. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h +641 -0
  10. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h +234 -0
  11. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h +235 -0
  12. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h +67 -0
  13. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h +305 -0
  14. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h +118 -0
  15. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h +1388 -0
  16. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h +326 -0
  17. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h +419 -0
  18. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h +374 -0
  19. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h +297 -0
  20. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h +302 -0
  21. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h +479 -0
  22. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h +198 -0
  23. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h +59 -0
  24. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +754 -0
  25. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp +303 -0
  26. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp +223 -0
  27. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +603 -0
  28. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp +325 -0
  29. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h +926 -0
  30. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h +107 -0
  31. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h +105 -0
  32. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h +199 -0
  33. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h +1350 -0
  34. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h +1315 -0
  35. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h +375 -0
  36. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h +328 -0
  37. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +2118 -0
  38. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h +834 -0
  39. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +290 -0
  40. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h +892 -0
  41. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h +1887 -0
  42. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h +787 -0
  43. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h +818 -0
  44. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h +417 -0
  45. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h +253 -0
  46. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h +58 -0
  47. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h +408 -0
  48. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h +587 -0
  49. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h +821 -0
  50. build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h +1532 -0
.gitattributes CHANGED
@@ -23,3 +23,4 @@ build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=
23
  build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
24
  build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
25
  build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
23
  build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
24
  build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
25
  build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
26
+ build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over densely packed tensors in global memory
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/device_kernel.h"
38
+ #include "cutlass/reduction/kernel/reduce_split_k.h"
39
+ #include "cutlass/cuda_host_adapter.hpp"
40
+ /////////////////////////////////////////////////////////////////////////////////////////////////
41
+
42
+ namespace cutlass {
43
+ namespace reduction {
44
+ namespace device {
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ template <
49
+ typename ReductionKernel_
50
+ >
51
+ class ReduceSplitK {
52
+ public:
53
+ using ReductionKernel = ReductionKernel_;
54
+
55
+ using Shape = typename ReductionKernel::Shape;
56
+ using ReductionOp = typename ReductionKernel::ReductionOp;
57
+ using OutputOp = typename ReductionKernel::OutputOp;
58
+
59
+ using ElementWorkspace = typename ReductionKernel::ElementWorkspace;
60
+ using ElementAccumulator = typename ReductionKernel::ElementAccumulator;
61
+ using ElementOutput = typename ReductionKernel::ElementOutput;
62
+
63
+ using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef;
64
+ using OutputTensorRef = typename ReductionKernel::OutputTensorRef;
65
+
66
+ using StrideIndex = typename ReductionKernel::StrideIndex;
67
+
68
+ static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
69
+
70
+ /// Argument structure
71
+ struct Arguments {
72
+
73
+ //
74
+ // Data members
75
+ //
76
+
77
+ MatrixCoord problem_size{0,0};
78
+ int partitions{1};
79
+ size_t partition_stride{0};
80
+ WorkspaceTensorRef workspace{};
81
+ OutputTensorRef destination{};
82
+ OutputTensorRef source{};
83
+ typename OutputOp::Params output{};
84
+ typename ReductionOp::Params reduction{};
85
+
86
+ //
87
+ // Methods
88
+ //
89
+
90
+ /// Default ctor
91
+ Arguments() = default;
92
+
93
+ CUTLASS_HOST_DEVICE
94
+ Arguments(
95
+ MatrixCoord const & problem_size
96
+ ):
97
+ problem_size(problem_size) { }
98
+
99
+ CUTLASS_HOST_DEVICE
100
+ Arguments(
101
+ MatrixCoord problem_size_,
102
+ int partitions_,
103
+ size_t partition_stride_,
104
+ WorkspaceTensorRef workspace_,
105
+ OutputTensorRef destination_,
106
+ OutputTensorRef source_,
107
+ typename OutputOp::Params output_ = typename OutputOp::Params(),
108
+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
109
+ ):
110
+ problem_size(problem_size_),
111
+ partitions(partitions_),
112
+ partition_stride(partition_stride_),
113
+ workspace(workspace_),
114
+ destination(destination_),
115
+ source(source_),
116
+ output(output_),
117
+ reduction(reduction_)
118
+ {
119
+
120
+ }
121
+
122
+ };
123
+
124
+ private:
125
+ /// Kernel parameters object
126
+ typename ReductionKernel::Params params_;
127
+
128
+ public:
129
+ /// Constructs Reduction SplitK
130
+ ReduceSplitK() { }
131
+
132
+ /// Determines whether the ReduceSplitK can execute the given problem.
133
+ static Status can_implement(Arguments const &args) {
134
+
135
+ return Status::kSuccess;
136
+ }
137
+
138
+ /// Gets the workspace size
139
+ static size_t get_workspace_size(Arguments const &args) {
140
+ // needs no additional workspace
141
+ return 0;
142
+ }
143
+
144
+ /// Initializes Reduction state from arguments.
145
+ Status initialize(
146
+ Arguments const &args,
147
+ void *workspace = nullptr,
148
+ cudaStream_t stream = nullptr) {
149
+
150
+ // initialize the params structure from the arguments
151
+ params_ = typename ReductionKernel::Params(
152
+ args.problem_size,
153
+ args.partitions,
154
+ args.partition_stride,
155
+ args.workspace,
156
+ args.destination,
157
+ args.source,
158
+ args.output,
159
+ args.reduction
160
+ );
161
+
162
+ return Status::kSuccess;
163
+
164
+ }
165
+
166
+ /// Initializes Reduction kernel state from arguments.
167
+ Status update(Arguments const &args, void *workspace = nullptr) {
168
+
169
+ // update the params structure from the arguments
170
+ params_.workspace.reset(args.workspace.non_const_ref().data());
171
+ params_.destination.reset(args.destination.non_const_ref().data());
172
+ params_.source.reset(args.source.non_const_ref().data());
173
+ params_.output = args.output;
174
+ params_.reduction = args.reduction;
175
+
176
+ return Status::kSuccess;
177
+ }
178
+
179
+ /// Runs the kernel using initialized state.
180
+ Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
181
+
182
+ //
183
+ // Launch reduction kernel
184
+ //
185
+ dim3 block = ReductionKernel::block_shape();
186
+ dim3 grid = ReductionKernel::grid_shape(params_.problem_size);
187
+
188
+ if constexpr (kEnableCudaHostAdapter) {
189
+ CUTLASS_ASSERT(cuda_adapter);
190
+ if (cuda_adapter) {
191
+ void* kernel_params[] = {&params_};
192
+ cuda_adapter->launch(
193
+ grid, dim3(1,1,1), block, 0, stream, kernel_params, kernel_index);
194
+ }
195
+ }
196
+ else {
197
+ cutlass::arch::synclog_setup();
198
+ Kernel<ReductionKernel><<< grid, block, 0, stream >>>(params_);
199
+ }
200
+
201
+ cudaError_t result = cudaGetLastError();
202
+ return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
203
+ }
204
+
205
+
206
+ /// Runs the kernel using initialized state.
207
+ Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
208
+ return run(stream, cuda_adapter, kernel_index);
209
+ }
210
+
211
+ /// Runs the kernel using initialized state.
212
+ Status operator()(
213
+ Arguments const &args,
214
+ void *workspace = nullptr,
215
+ cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
216
+
217
+ Status status = initialize(args, workspace, stream);
218
+
219
+ if (status == Status::kSuccess) {
220
+ status = run(stream,cuda_adapter, kernel_index);
221
+ }
222
+
223
+ return status;
224
+ }
225
+
226
+ };
227
+
228
+ /////////////////////////////////////////////////////////////////////////////////////////////////
229
+
230
+ } // namespace kernel
231
+ } // namespace reduction
232
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over one or more ranks of an affine tensor
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/fast_math.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/device_kernel.h"
43
+
44
+ #include "cutlass/reduction/device/tensor_reduce_affine_strided.h"
45
+ #include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h"
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass {
50
+ namespace reduction {
51
+ namespace device {
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ /// Tensor reduction operator on specific CUTLASS layouts over exactly one index
56
+ template <
57
+ typename ElementOutput_,
58
+ typename ElementSource_,
59
+ typename Layout_,
60
+ typename ReductionOp_,
61
+ int VectorLength_ = 1,
62
+ typename ElementCompute_ = ElementOutput_
63
+ >
64
+ struct TensorReduction {
65
+
66
+ using ElementOutput = ElementOutput_;
67
+ using ElementSource = ElementSource_;
68
+ using Layout = Layout_;
69
+ using ReductionOp = ReductionOp_;
70
+ static int const kVectorLength = VectorLength_;
71
+ using ElementCompute = ElementCompute_;
72
+
73
+ using TensorCoord = typename Layout::TensorCoord;
74
+
75
+ /// Reduction operator
76
+ using ReductionDeviceStridedOperator = TensorReductionAffineStrided<
77
+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute
78
+ >;
79
+
80
+ using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous<
81
+ 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute
82
+ >;
83
+
84
+ //
85
+ // Data members
86
+ //
87
+
88
+ ReductionDeviceStridedOperator reduction_strided;
89
+ ReductionDeviceContiguousOperator reduction_contiguous;
90
+ int reduction_index;
91
+
92
+ //
93
+ // Methods
94
+ //
95
+
96
+ ///
97
+ TensorReduction(
98
+ TensorCoord extent,
99
+ int reduction_index_
100
+ ):
101
+ reduction_index(reduction_index_) {
102
+
103
+ Coord<4> extent_affine;
104
+
105
+ switch (reduction_index) {
106
+ case 0:
107
+ extent_affine[0] = extent[1];
108
+ extent_affine[1] = extent[2];
109
+ extent_affine[2] = extent[0];
110
+ extent_affine[3] = extent[3];
111
+ break;
112
+ case 1:
113
+ extent_affine[0] = extent[0];
114
+ extent_affine[1] = extent[2];
115
+ extent_affine[2] = extent[1];
116
+ extent_affine[3] = extent[3];
117
+ break;
118
+ case 2:
119
+ extent_affine[0] = extent[0];
120
+ extent_affine[1] = extent[1];
121
+ extent_affine[2] = extent[2];
122
+ extent_affine[3] = extent[3];
123
+ break;
124
+ case 3:
125
+ extent_affine[0] = extent[0];
126
+ extent_affine[1] = extent[1];
127
+ extent_affine[2] = extent[2];
128
+ extent_affine[3] = extent[3];
129
+ break;
130
+ default: break;
131
+ }
132
+
133
+ if (reduction_index == 3) {
134
+ reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine);
135
+ }
136
+ else {
137
+ reduction_strided = ReductionDeviceStridedOperator(extent_affine);
138
+ }
139
+ }
140
+
141
+ /// Simple check to verify the object is initialized correctly
142
+ bool good() const {
143
+ if (reduction_index == 3) {
144
+ return reduction_contiguous.good();
145
+ }
146
+ return reduction_strided.good();
147
+ }
148
+
149
+ /// Size of one workspace
150
+ int64_t workspace_stride() const {
151
+ if (reduction_index == 3) {
152
+ return reduction_contiguous.workspace_stride();
153
+ }
154
+ else {
155
+ return reduction_strided.workspace_stride();
156
+ }
157
+ }
158
+
159
+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
160
+ int64_t workspace_size() const {
161
+ if (reduction_index == 3) {
162
+ return reduction_contiguous.workspace_size();
163
+ }
164
+ else {
165
+ return reduction_strided.workspace_size();
166
+ }
167
+ }
168
+
169
+ /// Helper to use overloaded function call operator
170
+ Status reduce(
171
+ TensorRef<ElementOutput, Layout> dst_ref,
172
+ TensorRef<ElementSource, Layout> src_ref,
173
+ void *device_workspace_ptr = nullptr,
174
+ ElementCompute reduction_identity = ElementCompute(),
175
+ ReductionOp reduction_op = ReductionOp(),
176
+ cudaStream_t stream = nullptr) {
177
+
178
+ int64_t src_stride[3];
179
+ int64_t dst_stride[3];
180
+
181
+ switch (reduction_index) {
182
+ case 0:
183
+ src_stride[0] = src_ref.stride()[1];
184
+ src_stride[1] = src_ref.stride()[0];
185
+ src_stride[2] = src_ref.stride()[2];
186
+ dst_stride[0] = dst_ref.stride()[1];
187
+ dst_stride[1] = dst_ref.stride()[0];
188
+ break;
189
+ case 1:
190
+ src_stride[0] = src_ref.stride()[2];
191
+ src_stride[1] = src_ref.stride()[0];
192
+ src_stride[2] = src_ref.stride()[1];
193
+ dst_stride[0] = dst_ref.stride()[2];
194
+ dst_stride[1] = dst_ref.stride()[0];
195
+ break;
196
+ case 2:
197
+ src_stride[0] = src_ref.stride()[2];
198
+ src_stride[1] = src_ref.stride()[1];
199
+ src_stride[2] = src_ref.stride()[0];
200
+ dst_stride[0] = dst_ref.stride()[2];
201
+ dst_stride[1] = dst_ref.stride()[1];
202
+ break;
203
+ case 3:
204
+ src_stride[0] = src_ref.stride()[2];
205
+ src_stride[1] = src_ref.stride()[1];
206
+ src_stride[2] = src_ref.stride()[0];
207
+
208
+ dst_stride[0] = dst_ref.stride()[2];
209
+ dst_stride[1] = dst_ref.stride()[1];
210
+ dst_stride[2] = dst_ref.stride()[0];
211
+
212
+ default: break;
213
+ }
214
+
215
+ if (reduction_index == 3) {
216
+ return reduction_contiguous(
217
+ dst_ref.data(),
218
+ dst_stride,
219
+ src_ref.data(),
220
+ src_stride,
221
+ device_workspace_ptr,
222
+ reduction_identity,
223
+ reduction_op,
224
+ stream);
225
+ }
226
+ else {
227
+ return reduction_strided(
228
+ dst_ref.data(),
229
+ dst_stride,
230
+ src_ref.data(),
231
+ src_stride,
232
+ device_workspace_ptr,
233
+ reduction_identity,
234
+ reduction_op,
235
+ stream);
236
+ }
237
+ }
238
+
239
+ Status operator()(
240
+ TensorRef<ElementOutput, Layout> dst_ref,
241
+ TensorRef<ElementSource, Layout> src_ref,
242
+ void *device_workspace_ptr = nullptr,
243
+ ElementCompute reduction_identity = ElementCompute(),
244
+ ReductionOp reduction_op = ReductionOp(),
245
+ cudaStream_t stream = nullptr) {
246
+
247
+ return reduce(
248
+ dst_ref,
249
+ src_ref,
250
+ device_workspace_ptr,
251
+ reduction_identity,
252
+ reduction_op,
253
+ stream);
254
+ }
255
+ };
256
+
257
+ /////////////////////////////////////////////////////////////////////////////////////////////////
258
+
259
+ } // namespace device
260
+ } // namespace reduction
261
+ } // namespace cutlass
262
+
263
+ /////////////////////////////////////////////////////////////////////////////////////////////////
264
+
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over one or more ranks of an affine tensor
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/fast_math.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/device_kernel.h"
43
+
44
+ #include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reduction {
50
+ namespace device {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Tensor reduction operator on layouts which are affine
55
+ template <
56
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
57
+ int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2)
58
+ typename ElementOutput_,
59
+ typename ElementSource_,
60
+ typename ReductionOp_,
61
+ int VectorLength = 1,
62
+ typename ElementCompute_ = ElementOutput_,
63
+ int Threads = 256, ///< Number of participating threads
64
+ int BatchSize = 4 ///< Number of elements to load per batch
65
+ >
66
+ struct TensorReductionAffineContiguous {
67
+
68
+ static int const kRank = Rank;
69
+ static int const kReducedRank = ReducedRank;
70
+ static int const kVectorLength = VectorLength;
71
+ static int const kInnerRank = kRank - kReducedRank;
72
+ static int const kThreads = Threads;
73
+ static int const kBatchSize = BatchSize;
74
+
75
+ using ElementOutput = ElementOutput_;
76
+ using ElementSource = ElementSource_;
77
+ using ReductionOp = ReductionOp_;
78
+ using ElementCompute = ElementCompute_;
79
+
80
+ //
81
+ // Data members
82
+ //
83
+
84
+ /// Internal status field
85
+ Status status;
86
+
87
+ /// Extent of tensor in source layout
88
+ Coord<kRank> extent;
89
+
90
+ /// Number of points in the outer index space
91
+ int64_t outer_count;
92
+
93
+ /// Number of elements in the inner index space
94
+ int64_t inner_count;
95
+
96
+ /// Number of workspaces needed
97
+ int workspace_count;
98
+
99
+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner)
100
+ dim3 grid_shape;
101
+
102
+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner)
103
+ dim3 threadblock_shape;
104
+
105
+ /// CUDA grid shape for the final reduction step if needed
106
+ dim3 grid_final;
107
+
108
+ /// CUDA threadblock shape for the final reduction step if needed
109
+ dim3 threadblock_final;
110
+
111
+ private:
112
+ //
113
+ // Methods
114
+ //
115
+
116
+ /// Helper to reshape 'count' such that it is less than 2 x 'ext'
117
+ static int reshape_pow2(int ext, int count) {
118
+ if (ext > count) {
119
+ return 1;
120
+ }
121
+ int x = 1;
122
+ for (; count >= ext * 2; ) {
123
+ count >>= 1;
124
+ x <<= 1;
125
+ }
126
+ return x;
127
+ }
128
+
129
+ public:
130
+
131
+ /// Default ctor
132
+ TensorReductionAffineContiguous():
133
+ status(Status::kErrorInvalidProblem),
134
+ extent(),
135
+ outer_count(0),
136
+ inner_count(0),
137
+ workspace_count(0),
138
+ grid_shape(0, 0, 0),
139
+ threadblock_shape(0, 0, 0) { }
140
+
141
+ /// Constructor
142
+ TensorReductionAffineContiguous(
143
+ Coord<kRank> extent_,
144
+ int target_threadblock_count = 128
145
+ ):
146
+ status(Status::kSuccess),
147
+ extent(extent_),
148
+ outer_count(0),
149
+ inner_count(0),
150
+ workspace_count(0) {
151
+
152
+ //
153
+ // Plan the parallel mapping strategy.
154
+ //
155
+
156
+ outer_count = 1;
157
+ inner_count = 1;
158
+
159
+ // Compute number of elements in strided ranks
160
+ for (int p = 0; p < kReducedRank; ++p) {
161
+ outer_count *= extent[p];
162
+ }
163
+
164
+ for (int p = 0; p < kInnerRank; ++p) {
165
+ inner_count *= extent[kReducedRank + p];
166
+ }
167
+
168
+ int cta_count_x = 1;
169
+ int cta_count_y = 1;
170
+ int cta_count_z = 1;
171
+
172
+ int cta_threads_x = kThreads;
173
+ int cta_threads_y = 1;
174
+ int cta_threads_z = 1;
175
+
176
+ // Determine CTA shape
177
+ int64_t inner_vector_count = inner_count / kVectorLength;
178
+
179
+ // Priority 1. Assign threadblocks to outer indices if possible
180
+ if (outer_count > target_threadblock_count) {
181
+ cta_count_x = 1;
182
+ cta_count_y = target_threadblock_count;
183
+ cta_count_z = 1;
184
+ }
185
+ else {
186
+
187
+ cta_count_y = int(outer_count);
188
+ int remaining_ctas = target_threadblock_count / cta_count_y;
189
+
190
+ // Priority 2. Assign inner dimensions to one CTA
191
+ if (inner_vector_count > cta_threads_x) {
192
+ int64_t cta_z_bound = inner_vector_count / cta_threads_x;
193
+ if (cta_z_bound > remaining_ctas) {
194
+ cta_count_z = remaining_ctas;
195
+ }
196
+ else {
197
+ cta_count_z = int(cta_z_bound);
198
+ }
199
+ }
200
+ else {
201
+ cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x);
202
+ cta_count_z = 1;
203
+ }
204
+ }
205
+
206
+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z);
207
+ threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z);
208
+
209
+ workspace_count = (cta_count_z > 1 ? cta_count_z : 0);
210
+
211
+ // Determine shape of final reduction kernel if needed
212
+ if (workspace_count) {
213
+
214
+ int final_threads = kThreads;
215
+ int final_ctas = 1;
216
+
217
+ if (outer_count > kThreads) {
218
+ final_ctas = int(outer_count + kThreads - 1) / kThreads;
219
+ }
220
+ else {
221
+ final_threads = int(outer_count);
222
+ }
223
+
224
+ grid_final = dim3(final_ctas, 1, 1);
225
+ threadblock_final = dim3(final_threads, 1, 1);
226
+ }
227
+ else {
228
+ grid_final = dim3(0, 0, 0);
229
+ threadblock_final = dim3(0, 0, 0);
230
+ }
231
+ }
232
+
233
+ /// Simple check to verify the object is initialized correctly
234
+ bool good() const {
235
+ return status == Status::kSuccess;
236
+ }
237
+
238
+ /// Size (in bytes) of <outer_count> workspace elements which are densely packed together
239
+ int64_t workspace_stride() const {
240
+
241
+ // Error condition
242
+ if (!good()) {
243
+ return 0;
244
+ }
245
+
246
+ return outer_count * sizeof_bits<ElementCompute>::value / 8;
247
+ }
248
+
249
+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
250
+ int64_t workspace_size() const {
251
+
252
+ // Error condition
253
+ if (!good()) {
254
+ return 0;
255
+ }
256
+
257
+ // No reduction across CTAs
258
+ if (grid_shape.z == 1) {
259
+ return 0;
260
+ }
261
+
262
+ return workspace_stride() * grid_shape.z;
263
+ }
264
+
265
+ /// Performs a reduction
266
+ Status reduce(
267
+ ElementOutput *dst_ptr, ///< Pointer to destination tensor
268
+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
269
+ ElementSource const *src_ptr, ///< Pointer to source tensor
270
+ int64_t src_stride[], ///< Stride vector (of length kRank - 1)
271
+ void *device_workspace_ptr = nullptr, ///< Device workspace
272
+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element
273
+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
274
+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
275
+
276
+ // Initial status check
277
+ if (!good()) {
278
+ return status;
279
+ }
280
+
281
+ // Guard against null workspace
282
+ if (workspace_count > 1 && device_workspace_ptr == nullptr) {
283
+ return Status::kErrorWorkspaceNull;
284
+ }
285
+
286
+ // Define reduction kernel
287
+ using ReductionKernel = kernel::TensorReductionAffineContiguous<
288
+ kRank,
289
+ kReducedRank,
290
+ ElementOutput,
291
+ ElementSource,
292
+ ReductionOp,
293
+ kVectorLength,
294
+ ElementCompute,
295
+ kThreads>;
296
+
297
+ using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal<
298
+ kRank,
299
+ kReducedRank,
300
+ ElementOutput,
301
+ ElementSource,
302
+ ReductionOp,
303
+ kVectorLength,
304
+ ElementCompute,
305
+ kThreads>;
306
+
307
+ using Params = typename ReductionKernel::Params;
308
+
309
+ // Construct the parameters
310
+ Params params(
311
+ extent,
312
+ dst_ptr,
313
+ dst_stride,
314
+ src_ptr,
315
+ src_stride,
316
+ static_cast<ElementCompute *>(device_workspace_ptr),
317
+ workspace_stride(),
318
+ workspace_count,
319
+ reduction_op,
320
+ reduction_identity);
321
+
322
+ // Shared memory size
323
+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage);
324
+
325
+ // Launch the kernel
326
+ cutlass::arch::synclog_setup();
327
+ Kernel<ReductionKernel><<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params);
328
+
329
+ // Check error condition
330
+ if (cudaPeekAtLastError() == cudaSuccess) {
331
+ status = Status::kSuccess;
332
+ }
333
+ else {
334
+ status = Status::kErrorInternal;
335
+ }
336
+
337
+ // Final reduction kernel
338
+ if (workspace_count) {
339
+ Kernel<FinalReductionKernel><<< grid_final, threadblock_final, 0, stream >>>(params);
340
+ }
341
+
342
+ // Check error condition
343
+ if (cudaPeekAtLastError() == cudaSuccess) {
344
+ status = Status::kSuccess;
345
+ }
346
+ else {
347
+ status = Status::kErrorInternal;
348
+ }
349
+
350
+ return status;
351
+ }
352
+
353
+ /// Helper to use overloaded function call operator
354
+ Status operator()(
355
+ ElementOutput *dst_ptr, ///< Pointer to destination tensor
356
+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
357
+ ElementSource const *src_ptr, ///< Pointer to source tensor
358
+ int64_t src_stride[], ///< Stride vector (of length kRank - 1)
359
+ void *device_workspace_ptr = nullptr, ///< Pointer to device workspace
360
+ ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element
361
+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
362
+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
363
+
364
+ return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream);
365
+ }
366
+ };
367
+
368
+ /////////////////////////////////////////////////////////////////////////////////////////////////
369
+
370
+ } // namespace device
371
+ } // namespace reduction
372
+ } // namespace cutlass
373
+
374
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over one or more ranks of an affine tensor
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/fast_math.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/device_kernel.h"
43
+
44
+ #include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reduction {
50
+ namespace device {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Tensor reduction operator on layouts which are affine
55
+ template <
56
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
57
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
58
+ typename ElementOutput_,
59
+ typename ElementSource_,
60
+ typename ReductionOp_,
61
+ int VectorLength = 1,
62
+ typename ElementCompute_ = ElementOutput_,
63
+ int Threads = 256, ///< Number of participating threads
64
+ int BatchSize = 4 ///< Number of elements to load per batch
65
+ >
66
+ struct TensorReductionAffineStrided {
67
+
68
+ static int const kRank = Rank;
69
+ static int const kReducedRank = ReducedRank;
70
+ static int const kVectorLength = VectorLength;
71
+ static int const kInnerRank = kRank - kReducedRank;
72
+ static int const kThreads = Threads;
73
+ static int const kBatchSize = BatchSize;
74
+
75
+ using ElementOutput = ElementOutput_;
76
+ using ElementSource = ElementSource_;
77
+ using ReductionOp = ReductionOp_;
78
+ using ElementCompute = ElementCompute_;
79
+
80
+ //
81
+ // Data members
82
+ //
83
+
84
+ /// Internal status field
85
+ Status status;
86
+
87
+ /// Extent of tensor in source layout
88
+ Coord<kRank> extent;
89
+
90
+ /// Number of points in the outer index space
91
+ int64_t outer_count;
92
+
93
+ /// Number of elements in the inner index space
94
+ int64_t inner_count;
95
+
96
+ /// Number of workspaces needed
97
+ int workspace_count;
98
+
99
+ /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner)
100
+ dim3 grid_shape;
101
+
102
+ /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner)
103
+ dim3 threadblock_shape;
104
+
105
+ /// CUDA grid shape for the final reduction step if needed
106
+ dim3 grid_final;
107
+
108
+ /// CUDA threadblock shape for the final reduction step if needed
109
+ dim3 threadblock_final;
110
+
111
+ private:
112
+ //
113
+ // Methods
114
+ //
115
+
116
+ /// Helper to reshape 'count' such that it is less than 2 x 'ext'
117
+ static int reshape_pow2(int ext, int count) {
118
+ if (ext > count) {
119
+ return 1;
120
+ }
121
+ int x = 1;
122
+ for (; count >= ext * 2; ) {
123
+ count >>= 1;
124
+ x <<= 1;
125
+ }
126
+ return x;
127
+ }
128
+
129
+ public:
130
+
131
+ /// Default ctor
132
+ TensorReductionAffineStrided():
133
+ status(Status::kErrorInvalidProblem),
134
+ extent(),
135
+ outer_count(0),
136
+ inner_count(0),
137
+ workspace_count(0),
138
+ grid_shape(0, 0, 0),
139
+ threadblock_shape(0, 0, 0) { }
140
+
141
+ /// Constructor
142
+ TensorReductionAffineStrided(
143
+ Coord<kRank> extent_,
144
+ int target_threadblock_count = 128
145
+ ):
146
+ status(Status::kSuccess),
147
+ extent(extent_),
148
+ outer_count(0),
149
+ inner_count(0),
150
+ workspace_count(0) {
151
+
152
+ //
153
+ // Plan the parallel mapping strategy.
154
+ //
155
+
156
+ outer_count = 1;
157
+ inner_count = 1;
158
+
159
+ // Compute number of elements in strided ranks
160
+ for (int p = 0; p < kReducedRank - 1; ++p) {
161
+ outer_count *= extent[p];
162
+ }
163
+
164
+ for (int p = 0; p < kInnerRank; ++p) {
165
+ inner_count *= extent[kReducedRank + p - 1];
166
+ }
167
+
168
+ // Compute plan for the reduction
169
+ int extent_c = extent[kRank - 1];
170
+ int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength;
171
+
172
+ // Determine CTA shape
173
+ int cta_width = kThreads * kVectorLength;
174
+ int cta_ways = reshape_pow2(extent_c, cta_width);
175
+ int cta_threads_x = kThreads / cta_ways;
176
+
177
+ threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64));
178
+
179
+ // This leads to an error.
180
+ if (threadblock_shape.z > 1) {
181
+ if (threadblock_shape.y != 1) {
182
+ status = Status::kErrorInternal;
183
+ return;
184
+ }
185
+ }
186
+
187
+ // Determine grid shape
188
+ int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x;
189
+ int cta_count_y = std::max(1, target_threadblock_count / cta_count_x);
190
+
191
+ // Limit the number of CTAs assigned to outer dimension
192
+ if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) {
193
+ cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y;
194
+ }
195
+
196
+ // Limit the number of CTAs assigned to inner dimension
197
+ int cta_count_z = std::max(1, target_threadblock_count / cta_count_y);
198
+ if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) {
199
+ cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z;
200
+ }
201
+
202
+ grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z);
203
+ workspace_count = (cta_count_z > 1 ? cta_count_z : 0);
204
+
205
+ // Determine shape of final reduction kernel if needed
206
+ grid_final = dim3(cta_count_x, int(outer_count));
207
+ threadblock_final = dim3(cta_threads_x, 1, 1);
208
+ }
209
+
210
+ /// Simple check to verify the object is initialized correctly
211
+ bool good() const {
212
+ return status == Status::kSuccess;
213
+ }
214
+
215
+ /// Size of one CTA's workspace
216
+ int64_t workspace_stride() const {
217
+
218
+ // Error condition
219
+ if (!good()) {
220
+ return 0;
221
+ }
222
+
223
+ int vector_size_bytes = kVectorLength * sizeof_bits<ElementCompute>::value / 8;
224
+
225
+ return extent[kRank - 1] * vector_size_bytes;
226
+ }
227
+
228
+ /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
229
+ int64_t workspace_size() const {
230
+
231
+ // Error condition
232
+ if (!good()) {
233
+ return 0;
234
+ }
235
+
236
+ // No reduction across CTAs
237
+ if (grid_shape.z == 1) {
238
+ return 0;
239
+ }
240
+
241
+ return workspace_stride() * outer_count * grid_shape.z;
242
+ }
243
+
244
+ /// Performs a reduction
245
+ Status reduce(
246
+ ElementOutput *dst_ptr, ///< Pointer to destination tensor
247
+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
248
+ ElementSource const *src_ptr, ///< Pointer to source tensor
249
+ int64_t src_stride[], ///< Stride vector (of length kRank - 1)
250
+ void *device_workspace_ptr = nullptr, ///< Device workspace
251
+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity
252
+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
253
+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
254
+
255
+ // Initial status check
256
+ if (!good()) {
257
+ return status;
258
+ }
259
+
260
+ // Guard against null workspace
261
+ if (workspace_count > 1 && device_workspace_ptr == nullptr) {
262
+ return Status::kErrorWorkspaceNull;
263
+ }
264
+
265
+ // Define reduction kernel
266
+ using ReductionKernel = kernel::TensorReductionAffineStrided<
267
+ kRank,
268
+ kReducedRank,
269
+ ElementOutput,
270
+ ElementSource,
271
+ ReductionOp,
272
+ kVectorLength,
273
+ ElementCompute,
274
+ kThreads>;
275
+
276
+ using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal<
277
+ kRank,
278
+ kReducedRank,
279
+ ElementOutput,
280
+ ElementSource,
281
+ ReductionOp,
282
+ kVectorLength,
283
+ ElementCompute,
284
+ kThreads>;
285
+
286
+ using Params = typename ReductionKernel::Params;
287
+
288
+ // Construct the parameters
289
+ Params params(
290
+ extent,
291
+ dst_ptr,
292
+ dst_stride,
293
+ src_ptr,
294
+ src_stride,
295
+ static_cast<ElementCompute *>(device_workspace_ptr),
296
+ workspace_stride(),
297
+ workspace_count,
298
+ reduction_op,
299
+ reduction_identity);
300
+
301
+ // Shared memory size
302
+ int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage);
303
+
304
+ // Launch the kernel
305
+ cutlass::arch::synclog_setup();
306
+ Kernel<ReductionKernel><<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params);
307
+
308
+ // Check error condition
309
+ if (cudaPeekAtLastError() == cudaSuccess) {
310
+ status = Status::kSuccess;
311
+ }
312
+ else {
313
+ status = Status::kErrorInternal;
314
+ }
315
+
316
+ // Final reduction kernel
317
+ if (workspace_count) {
318
+
319
+ Kernel<FinalReductionKernel><<< grid_final, threadblock_final, 0, stream >>>(params);
320
+
321
+ // Check error condition
322
+ if (cudaPeekAtLastError() == cudaSuccess) {
323
+ status = Status::kSuccess;
324
+ }
325
+ else {
326
+ status = Status::kErrorInternal;
327
+ }
328
+ }
329
+
330
+ return status;
331
+ }
332
+
333
+ /// Helper to use overloaded function call operator
334
+ Status operator()(
335
+ ElementOutput *dst_ptr, ///< Pointer to destination tensor
336
+ int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
337
+ ElementSource const *src_ptr, ///< Pointer to source tensor
338
+ int64_t src_stride[], ///< Stride vector (of length kRank - 1)
339
+ void *device_workspace_ptr = nullptr, ///< Pointer to device workspace
340
+ ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity
341
+ ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
342
+ cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
343
+
344
+ return reduce(
345
+ dst_ptr,
346
+ dst_stride,
347
+ src_ptr,
348
+ src_stride,
349
+ device_workspace_ptr,
350
+ reduction_identity,
351
+ reduction_op,
352
+ stream);
353
+ }
354
+ };
355
+
356
+ /////////////////////////////////////////////////////////////////////////////////////////////////
357
+
358
+ } // namespace device
359
+ } // namespace reduction
360
+ } // namespace cutlass
361
+
362
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a final reduction for softmax
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/functional.h"
41
+ #include "cutlass/matrix_shape.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+ #include "cutlass/arch/memory.h"
44
+ #include "cutlass/arch/memory_sm75.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reduction {
50
+ namespace kernel {
51
+
52
+ template <
53
+ typename ElementNorm_,
54
+ typename ElementSum_,
55
+ typename ElementSoftmaxCompute_,
56
+ typename ThreadblockShape_,
57
+ bool GroupedProblem = false
58
+ >
59
+ class ApplySoftmaxFinalReduction {
60
+ public:
61
+
62
+ using ElementNorm = ElementNorm_;
63
+ using ElementSum = ElementSum_;
64
+ using ElementSoftmaxCompute = ElementSoftmaxCompute_;
65
+ using ThreadblockShape = ThreadblockShape_;
66
+ static const bool isGroupedProblem = GroupedProblem;
67
+
68
+ //
69
+ // Arguments
70
+ //
71
+
72
+ struct Arguments {
73
+
74
+ cutlass::gemm::GemmCoord* problem_sizes{nullptr};
75
+ cutlass::gemm::GemmCoord problem_size{};
76
+ ElementNorm* block_Norm{nullptr};
77
+ ElementSum* block_Sum{nullptr};
78
+ int64_t* offset_Norm_Device{nullptr};
79
+ int64_t* offset_Sum_Device{nullptr};
80
+ int64_t batch_stride_Max{0};
81
+ int64_t batch_stride_Sum{0};
82
+
83
+ //
84
+ // Methods
85
+ //
86
+ Arguments() { }
87
+
88
+ // Non-grouped constructor without batching
89
+ Arguments(
90
+ cutlass::gemm::GemmCoord problem_size,
91
+ ElementNorm* block_Norm,
92
+ ElementSum* block_Sum
93
+ ):
94
+ problem_size(problem_size),
95
+ block_Norm(block_Norm),
96
+ block_Sum(block_Sum),
97
+ problem_sizes(nullptr),
98
+ offset_Norm_Device(nullptr),
99
+ offset_Sum_Device(nullptr),
100
+ batch_stride_Max(0),
101
+ batch_stride_Sum(0)
102
+ {
103
+
104
+ }
105
+
106
+ // Non-grouped constructor with batching
107
+ Arguments(
108
+ cutlass::gemm::GemmCoord problem_size,
109
+ ElementNorm* block_Norm,
110
+ ElementSum* block_Sum,
111
+ int64_t batch_stride_Max,
112
+ int64_t batch_stride_Sum
113
+ ):
114
+ problem_size(problem_size),
115
+ block_Norm(block_Norm),
116
+ block_Sum(block_Sum),
117
+ batch_stride_Max(batch_stride_Max),
118
+ batch_stride_Sum(batch_stride_Sum),
119
+ problem_sizes(nullptr),
120
+ offset_Norm_Device(nullptr),
121
+ offset_Sum_Device(nullptr)
122
+ {
123
+
124
+ }
125
+
126
+
127
+ // Grouped constructor
128
+ Arguments(
129
+ cutlass::gemm::GemmCoord *problem_sizes,
130
+ ElementNorm* block_Norm,
131
+ ElementSum* block_Sum,
132
+ int64_t* offset_Norm_Device,
133
+ int64_t* offset_Sum_Device
134
+ ):
135
+ problem_sizes(problem_sizes),
136
+ problem_size(cutlass::gemm::GemmCoord(0, 0, 0)),
137
+ block_Norm(block_Norm),
138
+ block_Sum(block_Sum),
139
+ offset_Norm_Device(offset_Norm_Device),
140
+ offset_Sum_Device(offset_Sum_Device)
141
+ {
142
+
143
+ }
144
+ };
145
+
146
+ struct SharedStorage {
147
+
148
+
149
+ };
150
+
151
+ //
152
+ // Params struct
153
+ //
154
+
155
+ struct Params {
156
+ Arguments args;
157
+
158
+ //
159
+ // Methods
160
+ //
161
+ Params() { }
162
+
163
+ Params(Arguments const &args_): args(args_) { }
164
+ };
165
+
166
+ private:
167
+
168
+ public:
169
+
170
+ CUTLASS_DEVICE
171
+ ApplySoftmaxFinalReduction() { }
172
+
173
+ CUTLASS_DEVICE
174
+ void operator()(Params const &params, SharedStorage &shared_storage) {
175
+
176
+ apply(params, shared_storage);
177
+ }
178
+
179
+ private:
180
+
181
+ /// Full reduction
182
+ CUTLASS_DEVICE
183
+ void apply(Params const &params, SharedStorage &shared_storage) {
184
+
185
+ int tid = threadIdx.x;
186
+ int bid = blockIdx.x;
187
+ int bdim = blockDim.x;
188
+
189
+ int block_batch = blockIdx.z;
190
+
191
+ // defining three vars for a general reduction module
192
+ cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size;
193
+ int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim;
194
+ int access_offset = isGroupedProblem ? 0 : bid * bdim;
195
+
196
+ if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return;
197
+
198
+ ElementNorm *curr_ptr_Max = isGroupedProblem ? \
199
+ params.args.block_Norm + params.args.offset_Norm_Device[bid] : \
200
+ params.args.block_Norm + block_batch * params.args.batch_stride_Max;
201
+ ElementSum *curr_ptr_Sum = isGroupedProblem ? \
202
+ params.args.block_Sum + params.args.offset_Sum_Device[bid] : \
203
+ params.args.block_Sum + block_batch * params.args.batch_stride_Sum;
204
+
205
+ int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
206
+
207
+ using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
208
+ using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
209
+
210
+ using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
211
+ using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
212
+
213
+ ConvertSum convert_sum;
214
+ ConvertNorm convert_norm;
215
+
216
+ ConvertSumOutput convert_sum_output;
217
+ ConvertNormOutput convert_norm_output;
218
+
219
+ uint32_t float_max_bits = 0xff7fffff;
220
+ float min_float = reinterpret_cast<float const &>(float_max_bits);
221
+
222
+ CUTLASS_PRAGMA_UNROLL
223
+ for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) {
224
+ ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset;
225
+ ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset;
226
+ ElementNorm *access_n_bak = access_n;
227
+ ElementSum *access_s_bak = access_s;
228
+ ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float);
229
+ ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0);
230
+ ElementNorm fetch_n;
231
+ ElementSum fetch_s;
232
+
233
+ CUTLASS_PRAGMA_UNROLL
234
+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) {
235
+ cutlass::arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
236
+ max_val = cutlass::fast_max(max_val, convert_norm(fetch_n));
237
+ access_n += problem_size.m();
238
+ }
239
+
240
+ access_n = access_n_bak;
241
+
242
+ CUTLASS_PRAGMA_UNROLL
243
+ for (int idx_n = 0; idx_n < threadblock_num; idx_n++) {
244
+ cutlass::arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
245
+ cutlass::arch::global_load<ElementSum, sizeof(ElementSum)>(fetch_s, access_s, true);
246
+ sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val);
247
+ access_n += problem_size.m();
248
+ access_s += problem_size.m();
249
+ }
250
+
251
+ ElementSoftmaxCompute inv_sum = cutlass::constants::one<ElementSoftmaxCompute>() / sum_val;
252
+
253
+ access_n = access_n_bak;
254
+ access_s = access_s_bak;
255
+
256
+ access_n[0] = convert_norm_output(max_val);
257
+ access_s[0] = convert_sum_output(inv_sum);
258
+ }
259
+
260
+ }
261
+ };
262
+
263
+ /////////////////////////////////////////////////////////////////////////////////////////////////
264
+
265
+ } // namespace kernel
266
+ } // namespace reduction
267
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over densely packed tensors in global memory
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/tensor_ref.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/numeric_conversion.h"
44
+
45
+ #include "cutlass/layout/matrix.h"
46
+
47
+ /////////////////////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass {
50
+ namespace reduction {
51
+ namespace kernel {
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ template <
56
+ typename Shape_, ///< shape of CTA (concept: MatrixShape)
57
+ typename OutputOp_ , ///< output operator (concept: epilogue::thread operator)
58
+ typename ReductionOp_, ///< reduction operator (concept: ReductionOperator)
59
+ int PartitionsPerStage = 4 ///< number of partitions to issue
60
+ >
61
+ class ReduceSplitK {
62
+ public:
63
+
64
+ using Shape = Shape_;
65
+ using ReductionOp = ReductionOp_;
66
+ using OutputOp = OutputOp_;
67
+ static int const kElementsPerAccess = OutputOp::kCount;
68
+ static int const kPartitionsPerStage = PartitionsPerStage;
69
+
70
+ using ElementWorkspace = typename ReductionOp::Element;
71
+ using ElementAccumulator = typename ReductionOp::ElementAccumulator;
72
+ using ElementOutput = typename OutputOp::ElementOutput;
73
+
74
+ using WorkspaceTensorRef = TensorRef<ElementWorkspace, layout::RowMajor>;
75
+ using OutputTensorRef = TensorRef<ElementOutput, layout::RowMajor>;
76
+ using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index;
77
+
78
+ using FragmentWorkspace = AlignedArray<ElementWorkspace, kElementsPerAccess>;
79
+ using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
80
+ using FragmentOutput = AlignedArray<ElementOutput, kElementsPerAccess>;
81
+
82
+ //
83
+ // Types
84
+ //
85
+
86
+ /// Params structure
87
+ struct Params {
88
+
89
+ MatrixCoord problem_size;
90
+ int partitions;
91
+ size_t partition_stride;
92
+ WorkspaceTensorRef workspace;
93
+ OutputTensorRef destination;
94
+ OutputTensorRef source;
95
+ typename OutputOp::Params output;
96
+ typename ReductionOp::Params reduction;
97
+
98
+ //
99
+ // Methods
100
+ //
101
+
102
+ CUTLASS_HOST_DEVICE
103
+ Params() { }
104
+
105
+ CUTLASS_HOST_DEVICE
106
+ Params(
107
+ MatrixCoord problem_size_,
108
+ int partitions_,
109
+ size_t partition_stride_,
110
+ WorkspaceTensorRef workspace_,
111
+ OutputTensorRef destination_,
112
+ OutputTensorRef source_,
113
+ typename OutputOp::Params output_ = typename OutputOp::Params(),
114
+ typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
115
+ ):
116
+ problem_size(problem_size_),
117
+ partitions(partitions_),
118
+ partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
119
+ workspace(workspace_),
120
+ destination(destination_),
121
+ source(source_),
122
+ output(output_),
123
+ reduction(reduction_) {
124
+
125
+ }
126
+ };
127
+
128
+ struct SharedStorage { };
129
+
130
+
131
+ public:
132
+
133
+ /// Computes the grid size given a chosen threadblock shape
134
+ CUTLASS_HOST_DEVICE
135
+ static dim3 grid_shape(
136
+ cutlass::MatrixCoord problem_size) {
137
+
138
+ return dim3(
139
+ (problem_size.row() + Shape::kRow - 1) / Shape::kRow,
140
+ (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn);
141
+ }
142
+
143
+ /// Determines the threadblock shape
144
+ CUTLASS_HOST_DEVICE
145
+ static dim3 block_shape() {
146
+ return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow);
147
+ }
148
+
149
+ /// Perform a reduction
150
+ CUTLASS_DEVICE
151
+ void operator()(Params const &params, SharedStorage &storage) {
152
+
153
+ // Determine CTA position
154
+ MatrixCoord thread_offset(
155
+ MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y),
156
+ MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess)
157
+ );
158
+
159
+ // One guard conditional
160
+ if (!(thread_offset.row() < params.problem_size.row() &&
161
+ thread_offset.column() < params.problem_size.column())) {
162
+
163
+ return;
164
+ }
165
+
166
+
167
+ ReductionOp reduction_op(params.reduction);
168
+
169
+ FragmentAccumulator accumulator;
170
+
171
+ accumulator.clear();
172
+
173
+ //
174
+ // Load the first slice
175
+ //
176
+
177
+ char const *workspace_ptr =
178
+ reinterpret_cast<char const *>(
179
+ params.workspace.data() + params.workspace.offset(thread_offset));
180
+
181
+ FragmentWorkspace workspace_frag[kPartitionsPerStage];
182
+
183
+ //
184
+ // Construct the output operator
185
+ //
186
+
187
+ OutputOp output_op(params.output);
188
+
189
+ //
190
+ // Load and accumulate with a simple batched loading sequence.
191
+ //
192
+
193
+ CUTLASS_PRAGMA_NO_UNROLL
194
+ for (int k = 0; k < params.partitions; k += kPartitionsPerStage) {
195
+
196
+ CUTLASS_PRAGMA_UNROLL
197
+ for (int i = 0; i < kPartitionsPerStage; ++i) {
198
+ if (k + i < params.partitions) {
199
+ workspace_frag[i] = *reinterpret_cast<FragmentWorkspace const *>(workspace_ptr);
200
+ workspace_ptr += params.partition_stride;
201
+ }
202
+ }
203
+
204
+ CUTLASS_PRAGMA_UNROLL
205
+ for (int i = 0; i < kPartitionsPerStage; ++i) {
206
+ if (k + i < params.partitions) {
207
+ accumulator = reduction_op(accumulator, workspace_frag[i]);
208
+ }
209
+ }
210
+ }
211
+
212
+ //
213
+ // Conditionally load the source
214
+ //
215
+
216
+ FragmentOutput source_frag;
217
+
218
+ source_frag.clear();
219
+
220
+ FragmentOutput const *source_ptr = reinterpret_cast<FragmentOutput const *>(
221
+ params.source.data() + params.source.offset(thread_offset));
222
+
223
+ if (output_op.is_source_needed()) {
224
+ reinterpret_cast<FragmentOutput &>(source_frag) = *source_ptr;
225
+ }
226
+
227
+ //
228
+ // Compute the output
229
+ //
230
+
231
+ typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag);
232
+
233
+ //
234
+ // Store
235
+ //
236
+
237
+ FragmentOutput *dest_ptr = reinterpret_cast<FragmentOutput *>(
238
+ params.destination.data() + params.destination.offset(thread_offset));
239
+
240
+ *dest_ptr = reinterpret_cast<FragmentOutput const &>(output_frag);
241
+ }
242
+ };
243
+
244
+ /////////////////////////////////////////////////////////////////////////////////////////////////
245
+
246
+ } // namespace kernel
247
+ } // namespace reduction
248
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over one or more ranks of an affine tensor
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/fast_math.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/device_kernel.h"
43
+
44
+ #include "cutlass/reduction/thread/reduction_operators.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reduction {
50
+ namespace kernel {
51
+
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// Parameters structure
55
+ template <
56
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
57
+ int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks)
58
+ typename ElementOutput, ///< Data type of output tensor
59
+ typename ElementSource, ///< Data type of source tensor
60
+ typename ReductionOp, ///< Reduction operator
61
+ int VectorLength = 1, ///< Vector length for memory
62
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
63
+ int Threads = 256, ///< Number of participating threads
64
+ int BatchSize = 4 ///< Number of elements to load per batch
65
+ >
66
+ struct TensorReductionAffineContiguousParams {
67
+
68
+ static int const kRank = Rank;
69
+ static int const kReducedRank = ReducedRank;
70
+ static int const kVectorLength = VectorLength;
71
+ static int const kInnerRank = kRank - kReducedRank;
72
+ static int const kThreads = Threads;
73
+ static int const kBatchSize = BatchSize;
74
+
75
+ Coord<kRank> extent; /// Extent of source tensor
76
+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank
77
+ int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J
78
+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K
79
+ int64_t workspace_stride; /// stride (units of bytes) between workspace
80
+ int workspace_count; /// number of workspaces
81
+
82
+ uint64_t inner_count; /// Number of elements in reduced index space
83
+ uint64_t outer_count; /// Number of elements in outer index space
84
+
85
+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank
86
+ ElementSource const * source; /// Pointer to source pointer of rank kRank
87
+ ReductionOp reduction_op; /// Reduction operator
88
+ ElementCompute reduction_identity; /// Identity element used by reduction operator
89
+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions
90
+
91
+ //
92
+ // Methods
93
+ //
94
+
95
+ /// Ctor
96
+ CUTLASS_HOST_DEVICE
97
+ TensorReductionAffineContiguousParams() {
98
+
99
+ }
100
+
101
+ /// Ctor
102
+ TensorReductionAffineContiguousParams(
103
+ Coord<kRank> extent_, ///< Extent of source tensor
104
+ ElementOutput * dst_ptr_, ///< Output tensor data
105
+ int64_t dst_stride_[], ///< Stride (units of elements)
106
+ ElementSource const * src_ptr_, ///< Source tensor data
107
+ int64_t src_stride_[], ///< Stride (units of elements)
108
+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions
109
+ int64_t workspace_stride_, ///< Stride between workspaces
110
+ int workspace_count_, ///< Number of workspaces
111
+ ReductionOp reduction_op_, ///< Reduction operator
112
+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator
113
+ ):
114
+ extent(extent_),
115
+ inner_count(1),
116
+ outer_count(1),
117
+ destination(dst_ptr_),
118
+ source(src_ptr_),
119
+ device_workspace(device_workspace_),
120
+ workspace_stride(workspace_stride_),
121
+ workspace_count(workspace_count_),
122
+ reduction_op(reduction_op_),
123
+ reduction_identity(reduction_identity_) {
124
+
125
+ // Initialize divisors for fast div-mod
126
+ for (int p = 1; p < kRank; ++p) {
127
+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p]));
128
+ }
129
+
130
+ int input_size_bits = sizeof_bits<ElementSource>::value;
131
+ int output_size_bits = sizeof_bits<ElementOutput>::value;
132
+
133
+ // Compute strides in units of bytes
134
+ for (int p = 0; p < kReducedRank; ++p) {
135
+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8;
136
+ }
137
+
138
+ for (int p = 0; p < kRank - 1; ++p) {
139
+ src_stride[p] = src_stride_[p] * input_size_bits / 8;
140
+ }
141
+
142
+ // Compute number of elements in strided ranks
143
+ for (int p = 0; p < kReducedRank; ++p) {
144
+ outer_count *= uint64_t(extent[p]);
145
+ }
146
+
147
+ for (int p = 0; p < kInnerRank; ++p) {
148
+ inner_count *= uint64_t(extent[kRank - 1 - p]);
149
+ }
150
+ }
151
+ };
152
+
153
+ /////////////////////////////////////////////////////////////////////////////////////////////////
154
+
155
+ /// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous
156
+ /// rank. This leads to favorable vectorized memory accesses over the contiguous rank.
157
+ template <
158
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
159
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
160
+ typename ElementOutput, ///< Data type of output tensor
161
+ typename ElementSource, ///< Data type of source tensor
162
+ typename ReductionOp, ///< Reduction operator
163
+ int VectorLength = 1, ///< Vector length for memory
164
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
165
+ int Threads = 256, ///< Number of participating threads
166
+ int BatchSize = 4 ///< Number of elements to load per batch
167
+ >
168
+ class TensorReductionAffineContiguous {
169
+ public:
170
+
171
+ static int const kRank = Rank;
172
+ static int const kReducedRank = ReducedRank;
173
+ static int const kVectorLength = VectorLength;
174
+ static int const kInnerRank = kRank - kReducedRank;
175
+ static int const kThreads = Threads;
176
+ static int const kBatchSize = BatchSize;
177
+ using ComputeFragment = Array<ElementCompute, VectorLength>;
178
+ using SourceFragment = AlignedArray<ElementSource, VectorLength>;
179
+ using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
180
+
181
+ /// Shared memory allocation used for reduction within the CTA
182
+ struct SharedStorage {
183
+ Array<ElementCompute, kThreads * kVectorLength> workspace;
184
+ };
185
+
186
+ /// Parameters structure
187
+ using Params = TensorReductionAffineContiguousParams<
188
+ Rank,
189
+ ReducedRank,
190
+ ElementOutput,
191
+ ElementSource,
192
+ ReductionOp,
193
+ VectorLength,
194
+ ElementCompute,
195
+ Threads,
196
+ BatchSize
197
+ >;
198
+
199
+ private:
200
+
201
+ /// Computes the coordinate and offset of a given linear index
202
+ CUTLASS_DEVICE
203
+ void compute_inner_coord_and_offset_(
204
+ Params const &params,
205
+ Coord<kInnerRank> & coord,
206
+ int64_t &src_offset,
207
+ uint64_t linear_idx) const {
208
+
209
+ // Decompose into a coordinate of rank <kInnerRank>
210
+ coord = CoordinateDecomposition<kInnerRank>(linear_idx, &params.divmod[kRank - kInnerRank]);
211
+
212
+ // Compute an offset using the souce stride
213
+ src_offset = 0;
214
+ CUTLASS_PRAGMA_UNROLL
215
+ for (int i = 0; i < kInnerRank - 1; ++i) {
216
+ src_offset += coord[i] * params.src_stride[kReducedRank + i];
217
+ }
218
+ src_offset += coord[kInnerRank - 1] * sizeof_bits<ElementSource>::value / 8;
219
+ }
220
+
221
+ /// Computes the coordinate and offset of a given linear index
222
+ CUTLASS_DEVICE
223
+ void compute_outer_coord_and_offset_(
224
+ Params const &params,
225
+ Coord<kReducedRank> & coord,
226
+ int64_t &dst_offset,
227
+ int64_t &src_offset,
228
+ uint64_t linear_idx) const {
229
+
230
+ // Decompose into coordinate of rank <kReducedRank>
231
+ coord = CoordinateDecomposition<kReducedRank>(linear_idx, params.divmod);
232
+
233
+ // Compute offsets using destination and source strides
234
+ dst_offset = 0;
235
+ src_offset = 0;
236
+
237
+ CUTLASS_PRAGMA_UNROLL
238
+ for (int i = 0; i < kReducedRank; ++i) {
239
+ dst_offset += params.dst_stride[i] * coord[i];
240
+ src_offset += params.src_stride[i] * coord[i];
241
+ }
242
+ }
243
+
244
+ /// Reduces over the reduction indices yielding a single element
245
+ CUTLASS_DEVICE
246
+ ElementCompute reduce_indices_(
247
+ Params const &params,
248
+ ElementCompute *threadblock_workspace,
249
+ char const *src_byte_ptr,
250
+ int coord_c) {
251
+
252
+ NumericArrayConverter<ElementCompute, ElementSource, VectorLength> convert_source;
253
+ ReductionOp reduction_op(params.reduction_op);
254
+
255
+ //
256
+ // Early exit or initialize to identity element
257
+ //
258
+ if (!params.inner_count) {
259
+ return params.reduction_identity;
260
+ }
261
+
262
+ ComputeFragment accumulator;
263
+
264
+ CUTLASS_PRAGMA_UNROLL
265
+ for (int i = 0; i < int(accumulator.size()); ++i) {
266
+ accumulator[i] = params.reduction_identity;
267
+ }
268
+
269
+ // Compute the coordinate of the first access
270
+ int64_t src_byte_offset = 0;
271
+ Coord<kInnerRank> coord;
272
+
273
+ uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength;
274
+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
275
+
276
+ // Load the first vector
277
+ SourceFragment source_fragment[kBatchSize];
278
+
279
+ bool not_done = true;
280
+
281
+ // Iterate over vectors in a linearized reduction index space
282
+ while (not_done) {
283
+
284
+ bool guards[kBatchSize];
285
+
286
+ // Issue a batch of loads
287
+ CUTLASS_PRAGMA_UNROLL
288
+ for (int b = 0; b < kBatchSize; ++b) {
289
+
290
+ if (linear_idx < params.inner_count) {
291
+ source_fragment[b] = *reinterpret_cast<SourceFragment const *>(src_byte_ptr + src_byte_offset);
292
+ guards[b] = true;
293
+ }
294
+ else {
295
+ guards[b] = false;
296
+ not_done = false;
297
+ }
298
+
299
+ linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength;
300
+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
301
+ }
302
+
303
+ // Perform a batch of reduction operations
304
+ CUTLASS_PRAGMA_UNROLL
305
+ for (int b = 0; b < kBatchSize; ++b) {
306
+ if (guards[b]) {
307
+ auto cvt = convert_source(source_fragment[b]);
308
+
309
+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
310
+ reduction_op,
311
+ accumulator,
312
+ cvt);
313
+ }
314
+ }
315
+ };
316
+
317
+ //
318
+ // Reduction of vectors to scalar
319
+ //
320
+
321
+ ElementCompute reduced_accumulator = accumulator[0];
322
+
323
+ CUTLASS_PRAGMA_UNROLL
324
+ for (int i = 1; i < kVectorLength; ++i) {
325
+ reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]);
326
+ }
327
+
328
+ //
329
+ // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0}
330
+ //
331
+ // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column
332
+ //
333
+
334
+ int thread_count = blockDim.x * blockDim.z;
335
+ int thread_j = threadIdx.x + blockDim.x * threadIdx.z;
336
+ int thread_i = threadIdx.y;
337
+
338
+ ElementCompute *frag_ptr = reinterpret_cast<ElementCompute *>(threadblock_workspace) + thread_i * thread_count;
339
+
340
+ frag_ptr[thread_j] = reduced_accumulator;
341
+
342
+ //
343
+ // Reduce
344
+ //
345
+ CUTLASS_PRAGMA_NO_UNROLL
346
+ while (thread_count > 1) {
347
+ thread_count /= 2;
348
+
349
+ __syncthreads();
350
+
351
+ if (thread_j < thread_count) {
352
+ ElementCompute other = frag_ptr[thread_j + thread_count];
353
+
354
+ reduced_accumulator = reduction_op(reduced_accumulator, other);
355
+
356
+ frag_ptr[thread_j] = reduced_accumulator;
357
+ }
358
+
359
+ __syncthreads();
360
+ }
361
+
362
+
363
+ return reduced_accumulator;
364
+ }
365
+
366
+ public:
367
+
368
+ /// Perform a reduction
369
+ CUTLASS_DEVICE
370
+ void operator()(Params const &params, SharedStorage &shared_storage) {
371
+
372
+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
373
+
374
+ char const * src_byte_ptr = reinterpret_cast<char const *>(params.source);
375
+ char * dst_byte_ptr = nullptr;
376
+
377
+ // If performing a reduction across CTAs, redirect output to device workspace
378
+ if (gridDim.z == 1) {
379
+ dst_byte_ptr = reinterpret_cast<char *>(params.destination);
380
+ }
381
+ else {
382
+ dst_byte_ptr = reinterpret_cast<char *>(params.device_workspace);
383
+ }
384
+
385
+ uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
386
+
387
+ // Use modulo division to compute location
388
+ Coord<kReducedRank> outer_coord;
389
+ int64_t dst_byte_offset;
390
+ int64_t src_byte_offset;
391
+
392
+ compute_outer_coord_and_offset_(
393
+ params,
394
+ outer_coord,
395
+ dst_byte_offset,
396
+ src_byte_offset,
397
+ idx_linear);
398
+
399
+ if (gridDim.z == 1) {
400
+
401
+ /// Complete the reduction with no workspace
402
+ while (idx_linear < params.outer_count) {
403
+
404
+ ElementCompute result = reduce_indices_(
405
+ params,
406
+ shared_storage.workspace.data(),
407
+ src_byte_ptr + src_byte_offset,
408
+ coord_c);
409
+
410
+ // Store the result after possible final reduction within the CTA
411
+ if (threadIdx.z == 0 && threadIdx.x == 0) {
412
+
413
+ // Convert to output type and store
414
+ NumericConverter<ElementOutput, ElementCompute> convert_output;
415
+ ElementOutput cvt = convert_output(result);
416
+
417
+ *reinterpret_cast<ElementOutput *>(dst_byte_ptr + dst_byte_offset) = cvt;
418
+ }
419
+
420
+ __syncthreads();
421
+
422
+ // Update indices and pointers
423
+ idx_linear += gridDim.y * blockDim.y;
424
+
425
+ compute_outer_coord_and_offset_(
426
+ params,
427
+ outer_coord,
428
+ dst_byte_offset,
429
+ src_byte_offset,
430
+ idx_linear);
431
+
432
+ } // while
433
+ }
434
+ else {
435
+
436
+ /// Complete the reduction with workspace
437
+ while (idx_linear < params.outer_count) {
438
+
439
+ ElementCompute result = reduce_indices_(
440
+ params,
441
+ shared_storage.workspace.data(),
442
+ src_byte_ptr + src_byte_offset,
443
+ coord_c);
444
+
445
+ int64_t byte_offset =
446
+ blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits<ElementCompute>::value / 8;
447
+
448
+ // Store the result for final reduction
449
+ if (threadIdx.z == 0 && threadIdx.x == 0) {
450
+ *reinterpret_cast<ElementCompute *>(dst_byte_ptr + byte_offset) = result;
451
+ }
452
+
453
+ __syncthreads();
454
+
455
+ // Update indices and pointers
456
+ idx_linear += gridDim.y * blockDim.y;
457
+
458
+ compute_outer_coord_and_offset_(
459
+ params,
460
+ outer_coord,
461
+ dst_byte_offset,
462
+ src_byte_offset,
463
+ idx_linear);
464
+ } // while
465
+ }
466
+ }
467
+ };
468
+
469
+ /////////////////////////////////////////////////////////////////////////////////////////////////
470
+
471
+ /// Kernel to perform final reduction
472
+ template <
473
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
474
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
475
+ typename ElementOutput, ///< Data type of output tensor
476
+ typename ElementSource, ///< Data type of source tensor
477
+ typename ReductionOp, ///< Reduction operator
478
+ int VectorLength = 1, ///< Vector length for memory
479
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
480
+ int Threads = 256, ///< Number of participating threads
481
+ int BatchSize = 4 ///< Number of elements to load per batch
482
+ >
483
+ class TensorReductionAffineContiguousFinal {
484
+ public:
485
+
486
+ static int const kRank = Rank;
487
+ static int const kReducedRank = ReducedRank;
488
+ static int const kVectorLength = VectorLength;
489
+ static int const kInnerRank = kRank - kReducedRank;
490
+ static int const kThreads = Threads;
491
+ static int const kBatchSize = BatchSize;
492
+
493
+ /// Shared memory
494
+ struct SharedStorage { };
495
+
496
+ /// Parameters structure
497
+ using Params = TensorReductionAffineContiguousParams<
498
+ Rank,
499
+ ReducedRank,
500
+ ElementOutput,
501
+ ElementSource,
502
+ ReductionOp,
503
+ VectorLength,
504
+ ElementCompute,
505
+ Threads,
506
+ BatchSize
507
+ >;
508
+
509
+ private:
510
+
511
+ /// Computes the coordinate and offset of a given linear index
512
+ CUTLASS_DEVICE
513
+ void compute_outer_coord_and_offset_(
514
+ Params const &params,
515
+ Coord<kReducedRank> & coord,
516
+ int64_t &dst_offset,
517
+ uint64_t linear_idx) const {
518
+
519
+ // Decompose into coordinate of rank <kReducedRank>
520
+ coord = CoordinateDecomposition<kReducedRank>(linear_idx, params.divmod);
521
+
522
+ // Compute offsets using destination and source strides
523
+ dst_offset = 0;
524
+
525
+ CUTLASS_PRAGMA_UNROLL
526
+ for (int i = 0; i < kReducedRank; ++i) {
527
+ dst_offset += params.dst_stride[i] * coord[i];
528
+ }
529
+ }
530
+
531
+ /// Reduces over the reduction indices
532
+ CUTLASS_DEVICE
533
+ ElementCompute reduce_indices_(
534
+ Params const &params,
535
+ ElementCompute const *device_workspace) {
536
+
537
+ ReductionOp reduction_op(params.reduction_op);
538
+ char const *src_byte_ptr = reinterpret_cast<char const *>(device_workspace);
539
+
540
+ // Accumulated output
541
+ ElementCompute accumulator = params.reduction_identity;
542
+
543
+ for (int iter = 0; iter < params.workspace_count; ++iter) {
544
+ ElementCompute workspace_item = *reinterpret_cast<ElementCompute const *>(src_byte_ptr);
545
+
546
+ accumulator = reduction_op(accumulator, workspace_item);
547
+
548
+ src_byte_ptr += params.workspace_stride;
549
+ }
550
+
551
+ return accumulator;
552
+ }
553
+
554
+ public:
555
+
556
+ //
557
+ // Methods
558
+ //
559
+
560
+ /// Perform a reduction
561
+ CUTLASS_DEVICE
562
+ void operator()(Params const &params, SharedStorage &shared_storage) {
563
+
564
+ uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x;
565
+
566
+ char * dst_byte_ptr = reinterpret_cast<char *>(params.destination);
567
+
568
+ // Use modulo division to compute location
569
+ Coord<kReducedRank> outer_coord;
570
+ int64_t dst_byte_offset;
571
+
572
+ compute_outer_coord_and_offset_(
573
+ params,
574
+ outer_coord,
575
+ dst_byte_offset,
576
+ idx_linear);
577
+
578
+ /// Complete the reduction
579
+ while (idx_linear < params.outer_count) {
580
+
581
+ ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear);
582
+
583
+ // Convert to output type and store
584
+ NumericConverter<ElementOutput, ElementCompute> convert_output;
585
+
586
+ *reinterpret_cast<ElementOutput *>(dst_byte_ptr + dst_byte_offset) = convert_output(result);
587
+
588
+ // Update indices and pointers
589
+ idx_linear += gridDim.x * blockDim.x;
590
+
591
+ compute_outer_coord_and_offset_(
592
+ params,
593
+ outer_coord,
594
+ dst_byte_offset,
595
+ idx_linear);
596
+ }
597
+ }
598
+ };
599
+
600
+ /////////////////////////////////////////////////////////////////////////////////////////////////
601
+
602
+ } // namespace kernel
603
+ } // namespace reduction
604
+ } // namespace cutlass
605
+
606
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over one or more ranks of an affine tensor
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/fast_math.h"
40
+ #include "cutlass/numeric_types.h"
41
+ #include "cutlass/numeric_conversion.h"
42
+ #include "cutlass/device_kernel.h"
43
+
44
+ #include "cutlass/reduction/thread/reduction_operators.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace reduction {
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ namespace kernel {
54
+
55
+ /// Parameters structure
56
+ template <
57
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
58
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
59
+ typename ElementOutput, ///< Data type of output tensor
60
+ typename ElementSource, ///< Data type of source tensor
61
+ typename ReductionOp, ///< Reduction operator
62
+ int VectorLength = 1, ///< Vector length for memory
63
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
64
+ int Threads = 256, ///< Number of participating threads
65
+ int BatchSize = 4 ///< Number of elements to load per batch
66
+ >
67
+ struct TensorReductionAffineStridedParams {
68
+
69
+ static int const kRank = Rank;
70
+ static int const kReducedRank = ReducedRank;
71
+ static int const kVectorLength = VectorLength;
72
+ static int const kInnerRank = kRank - kReducedRank;
73
+ static int const kThreads = Threads;
74
+ static int const kBatchSize = BatchSize;
75
+
76
+ Coord<kRank> extent; /// Extent of source tensor
77
+ FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank
78
+ int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J
79
+ int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K
80
+ int64_t workspace_stride; /// stride (units of bytes) between workspace
81
+ int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace
82
+ int workspace_count; /// number of workspaces
83
+
84
+ uint64_t inner_count; /// Number of elements in reduced index space
85
+ uint64_t outer_count; /// Number of elements in outer index space
86
+
87
+ ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank
88
+ ElementSource const * source; /// Pointer to source pointer of rank kRank
89
+ ReductionOp reduction_op; /// Reduction operator
90
+ ElementCompute reduction_identity; /// Identity element for reduction operator
91
+ ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions
92
+
93
+ //
94
+ // Methods
95
+ //
96
+
97
+ /// Ctor
98
+ CUTLASS_HOST_DEVICE
99
+ TensorReductionAffineStridedParams() {
100
+
101
+ }
102
+
103
+ /// Ctor
104
+ TensorReductionAffineStridedParams(
105
+ Coord<kRank> extent_, ///< Extent of source tensor
106
+ ElementOutput * dst_ptr_, ///< Output tensor data
107
+ int64_t dst_stride_[], ///< Stride (units of elements)
108
+ ElementSource const * src_ptr_, ///< Source tensor data
109
+ int64_t src_stride_[], ///< Stride (units of elements)
110
+ ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions
111
+ int64_t workspace_stride_, ///< Stride between workspaces
112
+ int workspace_count_, ///< Number of workspaces
113
+ ReductionOp reduction_op_, ///< Reduction operator
114
+ ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator
115
+ ):
116
+ extent(extent_),
117
+ inner_count(1),
118
+ outer_count(1),
119
+ destination(dst_ptr_),
120
+ source(src_ptr_),
121
+ device_workspace(device_workspace_),
122
+ workspace_outer_stride(0),
123
+ workspace_stride(workspace_stride_),
124
+ workspace_count(workspace_count_),
125
+ reduction_op(reduction_op_),
126
+ reduction_identity(reduction_identity_) {
127
+
128
+ // Initialize divisors for fast div-mod
129
+ for (int p = 1; p < kRank; ++p) {
130
+ divmod[p - 1] = FastDivmodU64(uint64_t(extent[p]));
131
+ }
132
+
133
+ int input_size_bits = sizeof_bits<ElementSource>::value;
134
+ int output_size_bits = sizeof_bits<ElementOutput>::value;
135
+
136
+ workspace_outer_stride = workspace_stride * workspace_count;
137
+
138
+ // Compute strides in units of bytes
139
+ for (int p = 0; p < kReducedRank - 1; ++p) {
140
+ dst_stride[p] = dst_stride_[p] * output_size_bits / 8;
141
+ }
142
+
143
+ for (int p = 0; p < kRank - 1; ++p) {
144
+ src_stride[p] = src_stride_[p] * input_size_bits / 8;
145
+ }
146
+
147
+ // Compute number of elements in strided ranks
148
+ for (int p = 0; p < kReducedRank - 1; ++p) {
149
+ outer_count *= uint64_t(extent[p]);
150
+ }
151
+
152
+ for (int p = 0; p < kInnerRank; ++p) {
153
+ inner_count *= uint64_t(extent[kReducedRank + p - 1]);
154
+ }
155
+ }
156
+ };
157
+
158
+ /// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous
159
+ /// rank. This leads to favorable vectorized memory accesses over the contiguous rank.
160
+ template <
161
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
162
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
163
+ typename ElementOutput, ///< Data type of output tensor
164
+ typename ElementSource, ///< Data type of source tensor
165
+ typename ReductionOp, ///< Reduction operator
166
+ int VectorLength = 1, ///< Vector length for memory
167
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
168
+ int Threads = 256, ///< Number of participating threads
169
+ int BatchSize = 4 ///< Number of elements to load per batch
170
+ >
171
+ class TensorReductionAffineStrided {
172
+ public:
173
+
174
+ static int const kRank = Rank;
175
+ static int const kReducedRank = ReducedRank;
176
+ static int const kVectorLength = VectorLength;
177
+ static int const kInnerRank = kRank - kReducedRank;
178
+ static int const kThreads = Threads;
179
+ static int const kBatchSize = BatchSize;
180
+ using ComputeFragment = Array<ElementCompute, VectorLength>;
181
+ using SourceFragment = AlignedArray<ElementSource, VectorLength>;
182
+ using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
183
+
184
+ /// Shared memory allocation used for reduction within the CTA
185
+ struct SharedStorage {
186
+ Array<ElementCompute, kThreads * kVectorLength> workspace;
187
+ };
188
+
189
+ /// Parameters structure
190
+ using Params = TensorReductionAffineStridedParams<
191
+ Rank,
192
+ ReducedRank,
193
+ ElementOutput,
194
+ ElementSource,
195
+ ReductionOp,
196
+ VectorLength,
197
+ ElementCompute,
198
+ Threads,
199
+ BatchSize
200
+ >;
201
+
202
+ private:
203
+
204
+ /// Computes the coordinate and offset of a given linear index
205
+ CUTLASS_DEVICE
206
+ void compute_inner_coord_and_offset_(
207
+ Params const &params,
208
+ Coord<kInnerRank> & coord,
209
+ int64_t &src_offset,
210
+ uint64_t linear_idx) const {
211
+
212
+ // Decompose into coordinate
213
+ coord = CoordinateDecomposition<kInnerRank>(linear_idx, &params.divmod[kReducedRank - 1]);
214
+
215
+ // Compute linear offset
216
+ src_offset = 0;
217
+
218
+ CUTLASS_PRAGMA_UNROLL
219
+ for (int i = 0; i < kInnerRank; ++i) {
220
+ src_offset += params.src_stride[kReducedRank + i - 1] * coord[i];
221
+ }
222
+ }
223
+
224
+ /// Computes the coordinate and offset of a given linear index
225
+ CUTLASS_DEVICE
226
+ void compute_outer_coord_and_offset_(
227
+ Params const &params,
228
+ Coord<kReducedRank - 1> & coord,
229
+ int64_t &dst_offset,
230
+ int64_t &src_offset,
231
+ uint64_t linear_idx) const {
232
+
233
+ // Decompose linear coordinate
234
+ coord = CoordinateDecomposition<kReducedRank - 1>(linear_idx, params.divmod);
235
+
236
+ // Compute offset into tensors
237
+ dst_offset = 0;
238
+ src_offset = 0;
239
+
240
+ CUTLASS_PRAGMA_UNROLL
241
+ for (int i = 0; i < kReducedRank - 1; ++i) {
242
+ dst_offset += params.dst_stride[i] * coord[i];
243
+ src_offset += params.src_stride[i] * coord[i];
244
+ }
245
+ }
246
+
247
+ /// Reduces over the reduction indices
248
+ CUTLASS_DEVICE
249
+ ComputeFragment reduce_indices_(
250
+ Params const &params,
251
+ ElementCompute *threadblock_workspace,
252
+ char const *src_byte_ptr) {
253
+
254
+ NumericArrayConverter<ElementCompute, ElementSource, VectorLength> convert_source;
255
+ ReductionOp reduction_op(params.reduction_op);
256
+
257
+ // Accumulated output
258
+ ComputeFragment identity_frag;
259
+
260
+ CUTLASS_PRAGMA_UNROLL
261
+ for (int i = 0; i < int(identity_frag.size()); ++i) {
262
+ identity_frag[i] = params.reduction_identity;
263
+ }
264
+
265
+ if (!params.inner_count) {
266
+ return identity_frag;
267
+ }
268
+
269
+ ComputeFragment accumulator = identity_frag;
270
+
271
+ // Compute the coordinate of the first access
272
+ int64_t src_byte_offset = 0;
273
+ Coord<kInnerRank> coord;
274
+
275
+ uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z;
276
+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
277
+
278
+ // Load the first vector
279
+ SourceFragment source_fragment[kBatchSize];
280
+
281
+ bool not_done = true;
282
+
283
+ // Iterate over vectors in a linearized reduction index space
284
+ while (not_done) {
285
+
286
+ bool guards[kBatchSize];
287
+
288
+ // Issue a batch of loads
289
+ CUTLASS_PRAGMA_UNROLL
290
+ for (int b = 0; b < kBatchSize; ++b) {
291
+
292
+ if (linear_idx < params.inner_count) {
293
+ source_fragment[b] = *reinterpret_cast<SourceFragment const *>(src_byte_ptr + src_byte_offset);
294
+ guards[b] = true;
295
+ }
296
+ else {
297
+ guards[b] = false;
298
+ not_done = false;
299
+ }
300
+
301
+ linear_idx += blockDim.z * gridDim.z;
302
+ compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
303
+ }
304
+
305
+ // Perform a batch of reduction operations
306
+ CUTLASS_PRAGMA_UNROLL
307
+ for (int b = 0; b < kBatchSize; ++b) {
308
+ if (guards[b]) {
309
+
310
+ auto cvt = convert_source(source_fragment[b]);
311
+
312
+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
313
+ reduction_op,
314
+ accumulator,
315
+ cvt);
316
+ }
317
+ }
318
+ };
319
+
320
+ // Optional reduction within a CTA
321
+ if (blockDim.z > 1) {
322
+
323
+ // Linearized thread ID
324
+ int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
325
+
326
+ // all threads store to workspace
327
+ ComputeFragment *frag_ptr = reinterpret_cast<ComputeFragment *>(threadblock_workspace);
328
+
329
+ frag_ptr[thread_idx] = accumulator;
330
+
331
+ __syncthreads();
332
+
333
+ if (threadIdx.z == 0) {
334
+ // Load all additional block indices
335
+ for (int z = 1; z < blockDim.z; ++z) {
336
+ ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y];
337
+
338
+ accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
339
+ reduction_op,
340
+ accumulator,
341
+ frag);
342
+ }
343
+ }
344
+
345
+ __syncthreads();
346
+ }
347
+
348
+ return accumulator;
349
+ }
350
+
351
+ public:
352
+
353
+ /// Perform a reduction
354
+ CUTLASS_DEVICE
355
+ void operator()(Params const &params, SharedStorage &shared_storage) {
356
+
357
+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
358
+
359
+ char const * src_byte_ptr = reinterpret_cast<char const *>(params.source + coord_c);
360
+ char * dst_byte_ptr = nullptr;
361
+
362
+ // If performing a reduction across CTAs, redirect output to device workspace
363
+ if (gridDim.z == 1) {
364
+ dst_byte_ptr = reinterpret_cast<char *>(params.destination + coord_c);
365
+ }
366
+ else {
367
+ dst_byte_ptr = reinterpret_cast<char *>(params.device_workspace + coord_c);
368
+ }
369
+
370
+ // If the C index is out of bounds, exit
371
+ if (coord_c >= params.extent[kRank - 1]) {
372
+ return;
373
+ }
374
+
375
+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
376
+
377
+ // Use modulo division to compute location
378
+ Coord<kReducedRank - 1> outer_coord;
379
+ int64_t dst_byte_offset;
380
+ int64_t src_byte_offset;
381
+
382
+ compute_outer_coord_and_offset_(
383
+ params,
384
+ outer_coord,
385
+ dst_byte_offset,
386
+ src_byte_offset,
387
+ idx_linear);
388
+
389
+ if (gridDim.z == 1) {
390
+
391
+ /// Complete the reduction with no workspace
392
+ while (idx_linear < params.outer_count) {
393
+
394
+ ComputeFragment result;
395
+
396
+ result = reduce_indices_(
397
+ params,
398
+ shared_storage.workspace.data(),
399
+ src_byte_ptr + src_byte_offset);
400
+
401
+ // Store the result after possible final reduction within the CTA
402
+ if (threadIdx.z == 0) {
403
+
404
+ // Convert to output type and store
405
+ NumericArrayConverter<ElementOutput, ElementCompute, VectorLength> convert_output;
406
+ auto cvt = convert_output(result);
407
+
408
+ *reinterpret_cast<OutputFragment *>(dst_byte_ptr + dst_byte_offset) =
409
+ reinterpret_cast<OutputFragment const &>(cvt);
410
+ }
411
+
412
+ // Update indices and pointers
413
+ idx_linear += gridDim.y * blockDim.y;
414
+
415
+ compute_outer_coord_and_offset_(
416
+ params,
417
+ outer_coord,
418
+ dst_byte_offset,
419
+ src_byte_offset,
420
+ idx_linear);
421
+
422
+ } // while
423
+ }
424
+ else {
425
+
426
+ /// Complete the reduction with a device workspace
427
+ while (idx_linear < params.outer_count) {
428
+
429
+ ComputeFragment result;
430
+
431
+ result = reduce_indices_(
432
+ params,
433
+ shared_storage.workspace.data(),
434
+ src_byte_ptr + src_byte_offset);
435
+
436
+ // Store the result after possible final reduction within the CTA
437
+ if (threadIdx.z == 0) {
438
+
439
+ int64_t byte_offset =
440
+ blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride;
441
+
442
+ // No conversion - store in compute type
443
+ *reinterpret_cast<ComputeFragment *>(dst_byte_ptr + byte_offset) =
444
+ reinterpret_cast<ComputeFragment const &>(result);
445
+ }
446
+
447
+ // Update indices and pointers
448
+ idx_linear += gridDim.y * blockDim.y;
449
+
450
+ compute_outer_coord_and_offset_(
451
+ params,
452
+ outer_coord,
453
+ dst_byte_offset,
454
+ src_byte_offset,
455
+ idx_linear);
456
+
457
+ } // while (outer index)
458
+ } // if ()
459
+ }
460
+ };
461
+
462
+ /////////////////////////////////////////////////////////////////////////////////////////////////
463
+
464
+ /// Kernel to perform final reduction
465
+ template <
466
+ int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
467
+ int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
468
+ typename ElementOutput, ///< Data type of output tensor
469
+ typename ElementSource, ///< Data type of source tensor
470
+ typename ReductionOp, ///< Reduction operator
471
+ int VectorLength = 1, ///< Vector length for memory
472
+ typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
473
+ int Threads = 256, ///< Number of participating threads
474
+ int BatchSize = 4 ///< Number of elements to load per batch
475
+ >
476
+ class TensorReductionAffineStridedFinal {
477
+ public:
478
+
479
+ static int const kRank = Rank;
480
+ static int const kReducedRank = ReducedRank;
481
+ static int const kVectorLength = VectorLength;
482
+ static int const kInnerRank = kRank - kReducedRank;
483
+ static int const kThreads = Threads;
484
+ static int const kBatchSize = BatchSize;
485
+ using ComputeFragment = Array<ElementCompute, VectorLength>;
486
+ using SourceFragment = AlignedArray<ElementSource, VectorLength>;
487
+ using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
488
+
489
+ /// Shared memory
490
+ struct SharedStorage { };
491
+
492
+ /// Parameters structure
493
+ using Params = TensorReductionAffineStridedParams<
494
+ Rank,
495
+ ReducedRank,
496
+ ElementOutput,
497
+ ElementSource,
498
+ ReductionOp,
499
+ VectorLength,
500
+ ElementCompute,
501
+ Threads,
502
+ BatchSize
503
+ >;
504
+
505
+ private:
506
+
507
+ /// Computes the coordinate and offset of a given linear index
508
+ CUTLASS_DEVICE
509
+ void compute_outer_coord_and_offset_(
510
+ Params const &params,
511
+ Coord<kReducedRank - 1> & coord,
512
+ int64_t &dst_offset,
513
+ uint64_t linear_idx) const {
514
+
515
+ // Decompose linear index
516
+ coord = CoordinateDecomposition<kReducedRank - 1>(linear_idx, params.divmod);
517
+
518
+ // Compute tensor offset
519
+ dst_offset = 0;
520
+
521
+ CUTLASS_PRAGMA_UNROLL
522
+ for (int i = 0; i < kReducedRank - 1; ++i) {
523
+ dst_offset += params.dst_stride[i] * coord[i];
524
+ }
525
+ }
526
+
527
+ /// Reduces over the reduction indices
528
+ CUTLASS_DEVICE
529
+ ComputeFragment reduce_indices_(
530
+ Params const &params,
531
+ char *src_byte_ptr) {
532
+
533
+ ReductionOp reduction_op(params.reduction_op);
534
+
535
+ // Accumulated output
536
+ ComputeFragment identity_frag;
537
+
538
+ CUTLASS_PRAGMA_UNROLL
539
+ for (int i = 0; i < int(identity_frag.size()); ++i) {
540
+ identity_frag[i] = params.reduction_identity;
541
+ }
542
+
543
+ ComputeFragment accumulator = identity_frag;
544
+ ComputeFragment workspace_fragments[kBatchSize];
545
+
546
+ // Partially unrolled loop
547
+ for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) {
548
+
549
+ // Issue a batch of loads
550
+ CUTLASS_PRAGMA_UNROLL
551
+ for (int b = 0; b < kBatchSize; ++b) {
552
+ if (idx + b < params.workspace_count) {
553
+ workspace_fragments[b] =
554
+ *reinterpret_cast<ComputeFragment *>(src_byte_ptr);
555
+ }
556
+ else {
557
+ workspace_fragments[b] = identity_frag;
558
+ }
559
+ src_byte_ptr += + params.workspace_stride;
560
+ }
561
+
562
+ // Perform a reduction
563
+ CUTLASS_PRAGMA_UNROLL
564
+ for (int b = 0; b < kBatchSize; ++b) {
565
+ CUTLASS_PRAGMA_UNROLL
566
+ for (int i = 0; i < kVectorLength; ++i) {
567
+ accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]);
568
+ }
569
+ }
570
+ }
571
+
572
+ return accumulator;
573
+ }
574
+
575
+ public:
576
+
577
+ //
578
+ // Methods
579
+ //
580
+
581
+ /// Perform a reduction
582
+ CUTLASS_DEVICE
583
+ void operator()(Params const &params, SharedStorage &shared_storage) {
584
+
585
+ int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
586
+
587
+ char * src_byte_ptr = reinterpret_cast<char *>(params.device_workspace + coord_c);
588
+ char * dst_byte_ptr = reinterpret_cast<char *>(params.destination + coord_c);
589
+
590
+ // If the C index is out of bounds, exit
591
+ if (coord_c >= params.extent[kRank - 1]) {
592
+ return;
593
+ }
594
+
595
+ int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
596
+
597
+ // Use modulo division to compute location
598
+ Coord<kReducedRank - 1> outer_coord;
599
+ int64_t dst_byte_offset;
600
+
601
+ compute_outer_coord_and_offset_(
602
+ params,
603
+ outer_coord,
604
+ dst_byte_offset,
605
+ idx_linear);
606
+
607
+ /// Complete the reduction
608
+ while (idx_linear < params.outer_count) {
609
+
610
+ int64_t src_byte_offset = idx_linear * params.workspace_outer_stride;
611
+
612
+ ComputeFragment result = reduce_indices_(
613
+ params,
614
+ src_byte_ptr + src_byte_offset);
615
+
616
+ // Convert to output type and store
617
+ NumericArrayConverter<ElementOutput, ElementCompute, VectorLength> convert_output;
618
+ auto cvt = convert_output(result);
619
+
620
+ *reinterpret_cast<OutputFragment *>(dst_byte_ptr + dst_byte_offset) =
621
+ reinterpret_cast<OutputFragment const &>(cvt);
622
+
623
+ // Update indices and pointers
624
+ idx_linear += gridDim.y * blockDim.y;
625
+
626
+ compute_outer_coord_and_offset_(
627
+ params,
628
+ outer_coord,
629
+ dst_byte_offset,
630
+ idx_linear);
631
+ }
632
+ }
633
+ };
634
+
635
+ /////////////////////////////////////////////////////////////////////////////////////////////////
636
+
637
+ } // namespace kernel
638
+ } // namespace reduction
639
+ } // namespace cutlass
640
+
641
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines basic thread level reduction with specializations for Array<T, N>.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/numeric_types.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/half.h"
41
+ #include "cutlass/functional.h"
42
+
43
+ namespace cutlass {
44
+ namespace reduction {
45
+ namespace thread {
46
+
47
+ /// Structure to compute the thread level reduction
48
+ template <typename Op, typename T>
49
+ struct Reduce;
50
+
51
+ /////////////////////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Partial Specialization of Reduce for "plus" (a functional operator)
54
+ template <typename T>
55
+ struct Reduce< plus<T>, T > {
56
+
57
+ CUTLASS_HOST_DEVICE
58
+ T operator()(T lhs, T const &rhs) const {
59
+ plus<T> _op;
60
+ return _op(lhs, rhs);
61
+ }
62
+ };
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ /// Partial specialization of Reduce for Array<T, N>
67
+ template <typename T, int N>
68
+ struct Reduce < plus<T>, Array<T, N>> {
69
+
70
+ CUTLASS_HOST_DEVICE
71
+ Array<T, 1> operator()(Array<T, N> const &in) const {
72
+
73
+ Array<T, 1> result;
74
+ Reduce< plus<T>, T > scalar_reduce;
75
+ result.clear();
76
+
77
+ CUTLASS_PRAGMA_UNROLL
78
+ for (auto i = 0; i < N; ++i) {
79
+ result[0] = scalar_reduce(result[0], in[i]);
80
+ }
81
+
82
+ return result;
83
+ }
84
+ };
85
+
86
+ /////////////////////////////////////////////////////////////////////////////////////////////////
87
+
88
+ /// Partial specializations of Reduce for Array<half_t, N>
89
+ template <int N>
90
+ struct Reduce < plus<half_t>, Array<half_t, N> > {
91
+
92
+ CUTLASS_HOST_DEVICE
93
+ Array<half_t, 1> operator()(Array<half_t, N> const &input) {
94
+
95
+ Array<half_t, 1> result;
96
+
97
+ // If there is only 1 element - there is nothing to reduce
98
+ if( N ==1 ){
99
+
100
+ result[0] = input.front();
101
+
102
+ } else {
103
+
104
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
105
+
106
+ __half result_d;
107
+ Array<half_t, 1> const *in_ptr_half = reinterpret_cast<Array<half_t, 1> const *>(&input);
108
+ Array<half_t, 2> const *in_ptr_half2 = reinterpret_cast<Array<half_t, 2> const *>(&input);
109
+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
110
+
111
+ // Set initial result = first half2, in case N==2
112
+ __half2 tmp_result = x_in_half2[0];
113
+
114
+ CUTLASS_PRAGMA_UNROLL
115
+ for (int i = 1; i < N/2; ++i) {
116
+
117
+ tmp_result = __hadd2(x_in_half2[i], tmp_result);
118
+
119
+ }
120
+
121
+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
122
+
123
+ // One final step is needed for odd "N" (to add the (N-1)th element)
124
+ if( N%2 ){
125
+
126
+ __half last_element;
127
+ Array<half_t, 1> tmp_last;
128
+ Array<half_t, 1> *tmp_last_ptr = &tmp_last;
129
+ tmp_last_ptr[0] = in_ptr_half[N-1];
130
+ last_element = reinterpret_cast<__half const &>(tmp_last);
131
+
132
+ result_d = __hadd(result_d, last_element);
133
+
134
+ }
135
+
136
+ Array<half_t, 1> *result_ptr = &result;
137
+ *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
138
+
139
+ #else
140
+
141
+ Reduce< plus<half_t>, half_t > scalar_reduce;
142
+ result.clear();
143
+
144
+ CUTLASS_PRAGMA_UNROLL
145
+ for (auto i = 0; i < N; ++i) {
146
+
147
+ result[0] = scalar_reduce(result[0], input[i]);
148
+
149
+ }
150
+
151
+ #endif
152
+ }
153
+
154
+ return result;
155
+
156
+ }
157
+ };
158
+
159
+
160
+ /////////////////////////////////////////////////////////////////////////////////////////////////
161
+
162
+ /// Partial specializations of Reduce for AlignedArray<half_t, N>
163
+ template <int N>
164
+ struct Reduce < plus<half_t>, AlignedArray<half_t, N> > {
165
+
166
+ CUTLASS_HOST_DEVICE
167
+ Array<half_t, 1> operator()(AlignedArray<half_t, N> const &input) {
168
+
169
+ Array<half_t, 1> result;
170
+
171
+ // If there is only 1 element - there is nothing to reduce
172
+ if( N ==1 ){
173
+
174
+ result[0] = input.front();
175
+
176
+ } else {
177
+
178
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
179
+
180
+ __half result_d;
181
+ AlignedArray<half_t, 1> const *in_ptr_half = reinterpret_cast<AlignedArray<half_t, 1> const *>(&input);
182
+ AlignedArray<half_t, 2> const *in_ptr_half2 = reinterpret_cast<AlignedArray<half_t, 2> const *>(&input);
183
+ __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
184
+
185
+ // Set initial result = first half2, in case N==2
186
+ __half2 tmp_result = x_in_half2[0];
187
+
188
+ CUTLASS_PRAGMA_UNROLL
189
+ for (int i = 1; i < N/2; ++i) {
190
+
191
+ tmp_result = __hadd2(x_in_half2[i], tmp_result);
192
+
193
+ }
194
+
195
+ result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
196
+
197
+ // One final step is needed for odd "N" (to add the (N-1)th element)
198
+ if( N%2 ){
199
+
200
+ __half last_element;
201
+ AlignedArray<half_t, 1> tmp_last;
202
+ AlignedArray<half_t, 1> *tmp_last_ptr = &tmp_last;
203
+ tmp_last_ptr[0] = in_ptr_half[N-1];
204
+ last_element = reinterpret_cast<__half const &>(tmp_last);
205
+
206
+ result_d = __hadd(result_d, last_element);
207
+
208
+ }
209
+
210
+ Array<half_t, 1> *result_ptr = &result;
211
+ *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
212
+
213
+ #else
214
+
215
+ Reduce< plus<half_t>, half_t > scalar_reduce;
216
+ result.clear();
217
+
218
+ CUTLASS_PRAGMA_UNROLL
219
+ for (auto i = 0; i < N; ++i) {
220
+
221
+ result[0] = scalar_reduce(result[0], input[i]);
222
+
223
+ }
224
+
225
+ #endif
226
+ }
227
+
228
+ return result;
229
+
230
+ }
231
+ };
232
+ }
233
+ }
234
+ }
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Kernel performing a reduction over densely packed tensors in global memory
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/tensor_ref.h"
39
+ #include "cutlass/numeric_types.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/functional.h"
42
+ #include "cutlass/numeric_conversion.h"
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace cutlass {
47
+ namespace reduction {
48
+ namespace thread {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Mixed-precision reduction
53
+ template <
54
+ typename ElementAccumulator_,
55
+ typename Element_,
56
+ int Count = 1
57
+ >
58
+ struct ReduceAdd {
59
+
60
+ //
61
+ // Type definitions
62
+ //
63
+
64
+ using ElementAccumulator = ElementAccumulator_;
65
+ using Element = Element_;
66
+ static int const kCount = Count;
67
+
68
+ using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
69
+ using FragmentElement = cutlass::Array<Element, kCount>;
70
+
71
+ struct Params { };
72
+
73
+ //
74
+ // Data members
75
+ //
76
+
77
+ /// Parameters object
78
+ Params params;
79
+
80
+ //
81
+ // Methods
82
+ //
83
+
84
+ /// Constructor
85
+ CUTLASS_HOST_DEVICE
86
+ ReduceAdd(Params params_ = Params()): params(params_) { }
87
+
88
+ /// Operator
89
+ CUTLASS_HOST_DEVICE
90
+ FragmentAccumulator operator()(
91
+ FragmentAccumulator accumulator,
92
+ FragmentElement element) const {
93
+
94
+ plus<FragmentAccumulator> op;
95
+
96
+ NumericArrayConverter<
97
+ ElementAccumulator,
98
+ Element,
99
+ kCount,
100
+ PreferredRoundingMode<ElementAccumulator, Element>::kRound> converter;
101
+
102
+ return op(accumulator, converter(element));
103
+ }
104
+ };
105
+
106
+ /////////////////////////////////////////////////////////////////////////////////////////////////
107
+
108
+ namespace detail {
109
+
110
+ /// Special handling for binary operators
111
+ template <typename ReductionOp, typename Element, int N>
112
+ struct VectorizeArrayOperation {
113
+
114
+ using ValueType = Array<Element, N>;
115
+
116
+ CUTLASS_HOST_DEVICE
117
+ ValueType operator()(
118
+ ReductionOp const &reduction_op,
119
+ ValueType const &lhs,
120
+ ValueType const &rhs) const {
121
+
122
+ ValueType result;
123
+
124
+ CUTLASS_PRAGMA_UNROLL
125
+ for (int i = 0; i < N; ++i) {
126
+ result[i] = reduction_op(lhs[i], rhs[i]);
127
+ }
128
+
129
+ return result;
130
+ }
131
+ };
132
+
133
+ /////////////////////////////////////////////////////////////////////////////////////////////////
134
+
135
+ template <typename ReductionOp, typename Element, int N>
136
+ struct ReduceArrayOperation {
137
+
138
+ using ArrayType = Array<Element, N>;
139
+
140
+ CUTLASS_HOST_DEVICE
141
+ Element operator()(
142
+ ReductionOp const &reduction_op,
143
+ ArrayType const &array) const {
144
+
145
+ Element item = reduction_op(array[0], array[1]);
146
+
147
+ CUTLASS_PRAGMA_UNROLL
148
+ for (int i = 2; i < N; ++i) {
149
+ item = reduction_op(item, array[i]);
150
+ }
151
+
152
+ return item;
153
+ }
154
+ };
155
+
156
+ template <int N>
157
+ struct ReduceArrayOperation<logical_and<uint1b_t>, uint1b_t, N> {
158
+
159
+ using ArrayType = Array<uint1b_t, N>;
160
+
161
+ CUTLASS_HOST_DEVICE
162
+ uint1b_t operator()(
163
+ logical_and<uint1b_t> const &reduction_op,
164
+ ArrayType const &array) const {
165
+
166
+ uint8_t const *ptr = reinterpret_cast<uint8_t const *>(&array);
167
+ bool item = false;
168
+
169
+ CUTLASS_PRAGMA_UNROLL
170
+ for (int byte = 0; byte < (N + 7) / 8; ++byte) {
171
+ uint8_t bits = ptr[byte];
172
+ item = (item || !bits);
173
+ }
174
+
175
+ return uint1b_t{!item};
176
+ }
177
+ };
178
+
179
+ template <int N>
180
+ struct ReduceArrayOperation<logical_or<uint1b_t>, uint1b_t, N> {
181
+
182
+ using ArrayType = Array<uint1b_t, N>;
183
+
184
+ CUTLASS_HOST_DEVICE
185
+ uint1b_t operator()(
186
+ logical_and<uint1b_t> const &reduction_op,
187
+ ArrayType const &array) const {
188
+
189
+ uint8_t const *ptr = reinterpret_cast<uint8_t const *>(&array);
190
+ bool item = true;
191
+
192
+ CUTLASS_PRAGMA_UNROLL
193
+ for (int byte = 0; byte < (N + 7) / 8; ++byte) {
194
+ uint8_t bits = ptr[byte];
195
+ item = (item || bits);
196
+ }
197
+
198
+ return uint1b_t{item};
199
+ }
200
+ };
201
+
202
+ /////////////////////////////////////////////////////////////////////////////////////////////////
203
+
204
+ /// Helper function to infer template argument types
205
+ template <typename ReductionOp, typename Element, int N>
206
+ CUTLASS_HOST_DEVICE
207
+ Array<Element, N> ApplyArrayOperator(
208
+ ReductionOp const &reduction_op,
209
+ Array<Element, N> const &lhs,
210
+ Array<Element, N> const &rhs) {
211
+
212
+ VectorizeArrayOperation<ReductionOp, Element, N> vectorize_op;
213
+
214
+ return vectorize_op(reduction_op, lhs, rhs);
215
+ }
216
+
217
+ /// Helper to reduce an array
218
+ template <typename ReductionOp, typename Element, int N>
219
+ Element ReduceArray(ReductionOp const &reduction_op, Array<Element, N> const &array) {
220
+ ReduceArrayOperation<ReductionOp, Element, N> reduce_array_op;
221
+
222
+ return reduce_array_op(reduction_op, array);
223
+ }
224
+
225
+ /////////////////////////////////////////////////////////////////////////////////////////////////
226
+
227
+ } // namespace detail
228
+
229
+ /////////////////////////////////////////////////////////////////////////////////////////////////
230
+
231
+ } // namespace thread
232
+ } // namespace reduction
233
+ } // namespace cutlass
234
+
235
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defies functors for mapping blockIdx to partitions of the batched reduction computation.
33
+ */
34
+ #pragma once
35
+ #include "cutlass/coord.h"
36
+
37
+ namespace cutlass {
38
+ namespace reduction {
39
+ struct DefaultBlockSwizzle {
40
+ /// Ctor
41
+ CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
42
+
43
+ /// Swizzle the block index.
44
+ CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
45
+
46
+ ///
47
+ CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
48
+ Coord<3> const &OutputTile) {
49
+ assert(OutputTile[0] == 1 && OutputTile[1] == 1);
50
+ assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0);
51
+ dim3 grid;
52
+ grid.x = problem_size[0] * problem_size[1] * problem_size[2]
53
+ / OutputTile[2] ;
54
+ return grid;
55
+ }
56
+
57
+ ///
58
+ CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) {
59
+ assert(SubTile[0] == 1 && SubTile[1] == 1);
60
+ dim3 block = swizzle();
61
+ Coord<3> threadblock_offset =
62
+ make_Coord(0, 0, block.x * SubTile[2]);
63
+ return threadblock_offset;
64
+ }
65
+ };
66
+ } // namespace reduction
67
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /* \file
32
+ \brief Performs comparison between two elements with support for floating-point comparisons.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "numeric_types.h"
38
+ #include "complex.h"
39
+
40
+ namespace cutlass {
41
+
42
+ /////////////////////////////////////////////////////////////////////////////////////////////////
43
+
44
+ template <typename T, typename U = T>
45
+ CUTLASS_HOST_DEVICE
46
+ bool relatively_equal(T a, T b, U epsilon, U nonzero_floor);
47
+
48
+ /////////////////////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace detail {
51
+
52
+ // This floating-point comparison function implements the method described in
53
+ //
54
+ // https://floating-point-gui.de/errors/comparison/
55
+ //
56
+ template <typename T>
57
+ CUTLASS_HOST_DEVICE
58
+ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) {
59
+
60
+ #if defined(__CUDACC_RTC__)
61
+ using cuda::std::abs;
62
+ #else
63
+ using std::abs;
64
+ #endif
65
+
66
+ T abs_A = abs(a);
67
+ T abs_B = abs(b);
68
+ T diff = abs(a - b);
69
+ T zero = T(0);
70
+
71
+ if (a == b) {
72
+ return true;
73
+ }
74
+ else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) {
75
+ return diff < epsilon * nonzero_floor;
76
+ }
77
+
78
+ return diff < epsilon * (abs_A + abs_B);
79
+ }
80
+
81
+ } // namespace detail
82
+
83
+ /////////////////////////////////////////////////////////////////////////////////////////////////
84
+
85
+ template <>
86
+ CUTLASS_HOST_DEVICE
87
+ bool relatively_equal<bool>(bool a, bool b, bool, bool) {
88
+ return (a == b);
89
+ }
90
+
91
+ template <>
92
+ CUTLASS_HOST_DEVICE
93
+ bool relatively_equal<uint1b_t>(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) {
94
+ return (a == b);
95
+ }
96
+
97
+ template <>
98
+ CUTLASS_HOST_DEVICE
99
+ bool relatively_equal<int2b_t>(int2b_t a, int2b_t b, int2b_t, int2b_t) {
100
+ return (a == b);
101
+ }
102
+
103
+ template <>
104
+ CUTLASS_HOST_DEVICE
105
+ bool relatively_equal<uint2b_t>(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) {
106
+ return (a == b);
107
+ }
108
+
109
+ template <>
110
+ CUTLASS_HOST_DEVICE
111
+ bool relatively_equal<int4b_t>(int4b_t a, int4b_t b, int4b_t, int4b_t) {
112
+ return (a == b);
113
+ }
114
+
115
+ template <>
116
+ CUTLASS_HOST_DEVICE
117
+ bool relatively_equal<uint4b_t>(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) {
118
+ return (a == b);
119
+ }
120
+
121
+ template <>
122
+ CUTLASS_HOST_DEVICE
123
+ bool relatively_equal<int8_t>(int8_t a, int8_t b, int8_t, int8_t) {
124
+ return (a == b);
125
+ }
126
+
127
+ template <>
128
+ CUTLASS_HOST_DEVICE
129
+ bool relatively_equal<uint8_t>(uint8_t a, uint8_t b, uint8_t, uint8_t) {
130
+ return (a == b);
131
+ }
132
+
133
+ template <>
134
+ CUTLASS_HOST_DEVICE
135
+ bool relatively_equal<int16_t>(int16_t a, int16_t b, int16_t, int16_t) {
136
+ return (a == b);
137
+ }
138
+
139
+ template <>
140
+ CUTLASS_HOST_DEVICE
141
+ bool relatively_equal<uint16_t>(uint16_t a, uint16_t b, uint16_t, uint16_t) {
142
+ return (a == b);
143
+ }
144
+
145
+ template <>
146
+ CUTLASS_HOST_DEVICE
147
+ bool relatively_equal<int32_t>(int32_t a, int32_t b, int32_t, int32_t) {
148
+ return (a == b);
149
+ }
150
+
151
+ template <>
152
+ CUTLASS_HOST_DEVICE
153
+ bool relatively_equal<uint32_t>(uint32_t a, uint32_t b, uint32_t, uint32_t) {
154
+ return (a == b);
155
+ }
156
+
157
+ template <>
158
+ CUTLASS_HOST_DEVICE
159
+ bool relatively_equal<int64_t>(int64_t a, int64_t b, int64_t, int64_t) {
160
+ return (a == b);
161
+ }
162
+
163
+ template <>
164
+ CUTLASS_HOST_DEVICE
165
+ bool relatively_equal<uint64_t>(uint64_t a, uint64_t b, uint64_t, uint64_t) {
166
+ return (a == b);
167
+ }
168
+
169
+ /////////////////////////////////////////////////////////////////////////////////////////////////
170
+
171
+ template <>
172
+ CUTLASS_HOST_DEVICE
173
+ bool relatively_equal<float_e4m3_t>(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) {
174
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
175
+ }
176
+
177
+ template <>
178
+ CUTLASS_HOST_DEVICE
179
+ bool relatively_equal<float_e5m2_t>(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) {
180
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
181
+ }
182
+
183
+ template <>
184
+ CUTLASS_HOST_DEVICE
185
+ bool relatively_equal<half_t>(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) {
186
+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
187
+ }
188
+
189
+ template <>
190
+ CUTLASS_HOST_DEVICE
191
+ bool relatively_equal<bfloat16_t>(
192
+ bfloat16_t a,
193
+ bfloat16_t b,
194
+ bfloat16_t epsilon,
195
+ bfloat16_t nonzero_floor) {
196
+
197
+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
198
+ }
199
+
200
+ template <>
201
+ CUTLASS_HOST_DEVICE
202
+ bool relatively_equal<tfloat32_t>(
203
+ tfloat32_t a,
204
+ tfloat32_t b,
205
+ tfloat32_t epsilon,
206
+ tfloat32_t nonzero_floor) {
207
+
208
+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
209
+ }
210
+
211
+ template <>
212
+ CUTLASS_HOST_DEVICE
213
+ bool relatively_equal<float>(float a, float b, float epsilon, float nonzero_floor) {
214
+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
215
+ }
216
+
217
+
218
+ template <>
219
+ CUTLASS_HOST_DEVICE
220
+ bool relatively_equal<double>(double a, double b, double epsilon, double nonzero_floor) {
221
+ return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
222
+ }
223
+
224
+ template<typename T>
225
+ CUTLASS_HOST_DEVICE
226
+ bool relatively_equal(complex<T> a, complex<T> b, T epsilon, T nonzero_floor) {
227
+ #if defined(__CUDACC_RTC__)
228
+ using cuda::std::abs;
229
+ #else
230
+ using std::abs;
231
+ #endif
232
+
233
+ T abs_A = abs(a);
234
+ T abs_B = abs(b);
235
+ T diff = abs(a - b);
236
+ complex<T> zero = complex<T>{T{}, T{}};
237
+
238
+ if (a == b) {
239
+ return true;
240
+ }
241
+ else if (a == zero || b == zero || diff < nonzero_floor) {
242
+ return diff < epsilon * nonzero_floor;
243
+ }
244
+
245
+ return diff < epsilon * (abs_A + abs_B);
246
+ }
247
+
248
+ template <typename T>
249
+ CUTLASS_HOST_DEVICE
250
+ bool relatively_equal(complex<T> a, complex<T> b, complex<T> epsilon, complex<T> nonzero_floor) {
251
+ #if defined(__CUDACC_RTC__)
252
+ using cuda::std::abs;
253
+ #else
254
+ using std::abs;
255
+ #endif
256
+
257
+ T abs_A = abs(a);
258
+ T abs_B = abs(b);
259
+ complex<T> diff = a - b;
260
+ T abs_diff = abs(diff);
261
+ complex<T> zero = complex<T>{T{}, T{}};
262
+
263
+ if (a == b) {
264
+ return true;
265
+ }
266
+ else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) {
267
+ return abs_diff < abs(epsilon * nonzero_floor);
268
+ }
269
+
270
+ return abs_diff < abs(epsilon) * (abs_A + abs_B);
271
+ }
272
+
273
+
274
+ template <>
275
+ CUTLASS_HOST_DEVICE
276
+ bool relatively_equal<float_e2m3_t>(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) {
277
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
278
+ }
279
+
280
+ template <>
281
+ CUTLASS_HOST_DEVICE
282
+ bool relatively_equal<float_e3m2_t>(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) {
283
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
284
+ }
285
+
286
+ template <>
287
+ CUTLASS_HOST_DEVICE
288
+ bool relatively_equal<float_e2m1_t>(float_e2m1_t a, float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) {
289
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
290
+ }
291
+ template <>
292
+ CUTLASS_HOST_DEVICE
293
+ bool relatively_equal<float_ue8m0_t>(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) {
294
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
295
+ }
296
+
297
+ template <>
298
+ CUTLASS_HOST_DEVICE
299
+ bool relatively_equal<float_ue4m3_t>(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) {
300
+ return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
301
+ }
302
+
303
+ /////////////////////////////////////////////////////////////////////////////////////////////////
304
+
305
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+
39
+ #include "cutlass/array.h"
40
+
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/matrix_shape.h"
43
+
44
+ #include "cutlass/gemm/gemm.h"
45
+
46
+ /////////////////////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// CTA-wide semaphore for inter-CTA synchronization.
53
+ class Semaphore {
54
+ public:
55
+
56
+ int *lock;
57
+ bool wait_thread;
58
+ int state;
59
+
60
+ public:
61
+
62
+ /// Implements a semaphore to wait for a flag to reach a given value
63
+ CUTLASS_HOST_DEVICE
64
+ Semaphore(int *lock_, int thread_id):
65
+ lock(lock_),
66
+ wait_thread(thread_id < 0 || thread_id == 0),
67
+ state(-1) {
68
+
69
+ }
70
+
71
+ /// Permit fetching the synchronization mechanism early
72
+ CUTLASS_DEVICE
73
+ void fetch() {
74
+ if (wait_thread) {
75
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
76
+ asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
77
+ #else
78
+ asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
79
+ #endif
80
+ }
81
+ }
82
+
83
+ /// Gets the internal state
84
+ CUTLASS_DEVICE
85
+ int get_state() const {
86
+ return state;
87
+ }
88
+
89
+ /// Waits until the semaphore is equal to the given value
90
+ CUTLASS_DEVICE
91
+ void wait(int status = 0) {
92
+ while( __syncthreads_and(state != status) ) {
93
+ fetch();
94
+ }
95
+
96
+ __syncthreads();
97
+ }
98
+
99
+ /// Updates the lock with the given result
100
+ CUTLASS_DEVICE
101
+ void release(int status = 0) {
102
+ __syncthreads();
103
+
104
+ if (wait_thread) {
105
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
106
+ asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
107
+ #else
108
+ asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
109
+ #endif
110
+ }
111
+ }
112
+ };
113
+
114
+ /////////////////////////////////////////////////////////////////////////////////////////////////
115
+
116
+ } // namespace cutlass
117
+
118
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h ADDED
@@ -0,0 +1,1388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Provides a mechanism for packing and unpacking elements smaller than one byte
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/integer_subbyte.h"
38
+ #include "cutlass/fast_math.h"
39
+
40
+ namespace cutlass {
41
+
42
+ namespace detail {
43
+ // This is an implementation detail of cutlass::SubbyteReference and.
44
+ // cutlass::HostTensor. For a given logical element type Element,
45
+ // and its corresponding storage (physical) element type StorageUnit,
46
+ // it computes quantities that help with managing allocations.
47
+ //
48
+ // CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support
49
+ // packed arrays of subbyte types such as int4. Element is the "logical" type
50
+ // for computations, while CUTLASS uses StorageUnit as the element type
51
+ // of a packed array of Element. If Element is not a subbyte type,
52
+ // then the corresponding StorageUnit type is just Element itself.
53
+ //
54
+ // The ContainerType is always calculated as an array StorageUnit type (the StorageUnit
55
+ // is always a byte for subbyte types),
56
+ // and its number of bits is the lcm of the subbyte type's number of bits and 8.
57
+ // Below are some examples for different subbyte types.
58
+ //
59
+ // * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t)
60
+ // * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t)
61
+ template<class Element, class StorageUnit>
62
+ struct StorageContainerCalculator {
63
+ // kContainerTypeNumBits: The number of bits needed for ContainerType
64
+ static constexpr int kContainerTypeNumBits = (sizeof_bits<Element>::value < 8) ? cutlass::lcm_cxx11(sizeof_bits<Element>::value, sizeof_bits<StorageUnit>::value) : sizeof_bits<Element>::value;
65
+ static_assert(kContainerTypeNumBits % sizeof_bits<Element>::value == 0, "The bits of ContainerType should be divisible by the element's number of bits");
66
+ // kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance
67
+ static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits<Element>::value;
68
+ /// 3. kContainerTypeNumBytes: The number of bytes per ContainerType instance
69
+ static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8;
70
+ /// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType
71
+ static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits<StorageUnit>::value;
72
+
73
+ static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero");
74
+ static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero");
75
+ static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero");
76
+ };
77
+ }
78
+
79
+ /////////////////////////////////////////////////////////////////////////////////////////////////
80
+
81
+ /// This class provides a mechanism for packing and unpacking elements smaller than one byte. It
82
+ /// assumes these sub-byte elements are packed in a traditional C++ numeric type.
83
+ ///
84
+ /// The intended application is to provide a mechanism to indirectly reference elements in
85
+ /// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller
86
+ /// than one byte.
87
+ ///
88
+ /// Supports basic pointer arithmetic:
89
+ ///
90
+ /// Example:
91
+ ///
92
+ /// int4b_t *ptr = ...;
93
+ ///
94
+ /// SubbyteReference<int4b_t> ref = ptr;
95
+ /// ref += 15;
96
+ ///
97
+ /// int4b_t x = ref; // load an int4b_t
98
+ /// ref = x + 2_s4; // perform arithmetic on int4b_t and then store
99
+ ///
100
+ template <
101
+ typename Element_, /// CUTLASS numeric element type.
102
+ typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer
103
+ /// number of objects of type Element.
104
+ class = void
105
+ >
106
+ class ConstSubbyteReference {
107
+ public:
108
+
109
+ using Element = Element_;
110
+ using Storage = Storage_;
111
+ using StoragePointer = Storage const *;
112
+
113
+ static_assert(sizeof_bits<Element>::value <= sizeof_bits<Storage>::value,
114
+ "Size of Element must not be greater than Storage.");
115
+
116
+ static_assert(!(sizeof_bits<Storage>::value % sizeof_bits<Element>::value),
117
+ "Storage must be divisible by Element");
118
+
119
+ private:
120
+
121
+ ///! Number of elements per storage vector
122
+ int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
123
+
124
+ ///! Bit mask
125
+ Storage const kMask =
126
+ ((sizeof_bits<Element>::value < sizeof_bits<Storage>::value) ?
127
+ (Storage(1) << sizeof_bits<Element>::value) - Storage(1) :
128
+ ~Storage(0));
129
+
130
+ private:
131
+
132
+ /// Pointer to array containing element
133
+ StoragePointer ptr_;
134
+
135
+ /// Offset (in units of elements) from pointer.
136
+ ///
137
+ /// Invariant: must always be in range [0, kElementsPerVector)
138
+ int offset_;
139
+
140
+ public:
141
+
142
+ CUTLASS_HOST_DEVICE
143
+ ConstSubbyteReference(): ptr_(nullptr), offset_(0) { }
144
+
145
+ /// Constructor
146
+ CUTLASS_HOST_DEVICE
147
+ ConstSubbyteReference(
148
+ Element const *ptr, /// pointer to memory
149
+ int64_t offset /// logical offset in units of Element
150
+ ):
151
+ ptr_(reinterpret_cast<StoragePointer>(ptr)),
152
+ offset_(0) {
153
+
154
+ int64_t offset_in_vectors = offset / kElementsPerVector;
155
+ int64_t offset_in_elements = offset % kElementsPerVector;
156
+
157
+ ptr_ += offset_in_vectors;
158
+ offset_ = int(offset_in_elements);
159
+ }
160
+
161
+ /// Constructor
162
+ CUTLASS_HOST_DEVICE
163
+ ConstSubbyteReference(
164
+ Element *ptr = nullptr
165
+ ): ConstSubbyteReference(ptr, 0) { }
166
+
167
+ /// Gets storage pointer
168
+ CUTLASS_HOST_DEVICE
169
+ StoragePointer storage_pointer() const {
170
+ return ptr_;
171
+ }
172
+
173
+ /// Gets element offset within storage vector
174
+ CUTLASS_HOST_DEVICE
175
+ int element_offset() const {
176
+ return offset_;
177
+ }
178
+
179
+ /// Unpacks an element from memory
180
+ CUTLASS_HOST_DEVICE
181
+ Element get() const {
182
+ Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits<Element>::value)) & kMask);
183
+ return reinterpret_cast<Element const &>(item);
184
+ }
185
+
186
+ /// Unpacks an element from memory
187
+ CUTLASS_HOST_DEVICE
188
+ operator Element() const {
189
+ return get();
190
+ }
191
+
192
+ /// Adds an offset in units of elements to the reference
193
+ CUTLASS_HOST_DEVICE
194
+ ConstSubbyteReference &operator+=(int offset) {
195
+
196
+ offset += offset_;
197
+
198
+ int offset_in_vectors = offset / kElementsPerVector;
199
+ int offset_in_elements = offset % kElementsPerVector;
200
+
201
+ ptr_ += offset_in_vectors;
202
+ offset_ = offset_in_elements;
203
+
204
+ return *this;
205
+ }
206
+
207
+ /// Adds an offset in units of elements to the reference
208
+ CUTLASS_HOST_DEVICE
209
+ ConstSubbyteReference &operator+=(long long offset) {
210
+
211
+ offset += offset_;
212
+
213
+ long long offset_in_vectors = offset / kElementsPerVector;
214
+ int offset_in_elements = int(offset % kElementsPerVector);
215
+
216
+ ptr_ += offset_in_vectors;
217
+ offset_ = offset_in_elements;
218
+
219
+ return *this;
220
+ }
221
+
222
+ /// Adds an offset in units of elements to the reference
223
+ CUTLASS_HOST_DEVICE
224
+ ConstSubbyteReference &operator-=(int offset) {
225
+
226
+ int offset_in_vectors = offset / kElementsPerVector;
227
+ int offset_in_elements = offset % kElementsPerVector;
228
+
229
+ ptr_ -= offset_in_vectors;
230
+ offset_ -= offset_in_elements;
231
+
232
+ if (offset_ < 0) {
233
+ offset_ += kElementsPerVector;
234
+ --ptr_;
235
+ }
236
+
237
+ return *this;
238
+ }
239
+
240
+ /// Adds an offset in units of elements to the reference
241
+ CUTLASS_HOST_DEVICE
242
+ ConstSubbyteReference &operator-=(long long offset) {
243
+
244
+ long long offset_in_vectors = offset / kElementsPerVector;
245
+ int offset_in_elements = int(offset % kElementsPerVector);
246
+
247
+ ptr_ -= offset_in_vectors;
248
+ offset_ -= offset_in_elements;
249
+
250
+ if (offset_ < 0) {
251
+ offset_ += kElementsPerVector;
252
+ --ptr_;
253
+ }
254
+
255
+ return *this;
256
+ }
257
+
258
+ /// Returns a reference to an element with a given offset from the current reference
259
+ CUTLASS_HOST_DEVICE
260
+ ConstSubbyteReference operator+(int offset) const {
261
+
262
+ ConstSubbyteReference ref(ptr_, offset_);
263
+ ref += offset;
264
+
265
+ return ref;
266
+ }
267
+
268
+ /// Returns a reference to an element with a given offset from the current reference
269
+ CUTLASS_HOST_DEVICE
270
+ ConstSubbyteReference operator+(long long offset) const {
271
+
272
+ ConstSubbyteReference ref(ptr_, offset_);
273
+ ref += offset;
274
+
275
+ return ref;
276
+ }
277
+
278
+ /// Returns a reference to an element with a given offset from the current reference
279
+ CUTLASS_HOST_DEVICE
280
+ ConstSubbyteReference operator-(int offset) const {
281
+
282
+ ConstSubbyteReference ref(ptr_, offset_);
283
+ ref -= offset;
284
+
285
+ return ref;
286
+ }
287
+
288
+ /// Returns a reference to an element with a given offset from the current reference
289
+ CUTLASS_HOST_DEVICE
290
+ ConstSubbyteReference operator-=(long long offset) const {
291
+
292
+ ConstSubbyteReference ref(ptr_, offset_);
293
+ ref -= offset;
294
+
295
+ return ref;
296
+ }
297
+
298
+ /// Computes the difference in elements between references
299
+ CUTLASS_HOST_DEVICE
300
+ ptrdiff_t operator-(ConstSubbyteReference ref) const {
301
+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
302
+ }
303
+
304
+ /// Explicit cast to int
305
+ CUTLASS_HOST_DEVICE
306
+ explicit operator int() const {
307
+ return int(get());
308
+ }
309
+
310
+ /// Explicit cast to signed 64-bit integer
311
+ CUTLASS_HOST_DEVICE
312
+ explicit operator int64_t() const {
313
+ return int64_t(get());
314
+ }
315
+
316
+ /// Explicit cast to unsigned 64-bit integer
317
+ CUTLASS_HOST_DEVICE
318
+ explicit operator uint64_t() const {
319
+ return uint64_t(get());
320
+ }
321
+
322
+ /// Explicit cast to float
323
+ CUTLASS_HOST_DEVICE
324
+ explicit operator float() const {
325
+ return float(get());
326
+ }
327
+
328
+ /// Explicit cast to double
329
+ CUTLASS_HOST_DEVICE
330
+ explicit operator double() const {
331
+ return double(get());
332
+ }
333
+ };
334
+
335
+ template <
336
+ typename Element_, /// CUTLASS numeric element type.
337
+ typename Storage_ = /// Underlying storage type. Must be able to hold an integer
338
+ /// number of objects of type Element.
339
+
340
+ #if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads.
341
+ #if (__CUDA_ARCH__ >= 700) ///
342
+ uint16_t
343
+ #else
344
+ uint32_t
345
+ #endif
346
+ #else
347
+ uint8_t
348
+ #endif
349
+ ,
350
+ class = void
351
+ >
352
+ class SubbyteReference {
353
+ public:
354
+
355
+ using Element = Element_;
356
+ using Storage = Storage_;
357
+ using StoragePointer = Storage *;
358
+
359
+ static_assert(sizeof_bits<Element>::value <= sizeof_bits<Storage>::value,
360
+ "Size of Element must not be greater than Storage.");
361
+
362
+ static_assert(!(sizeof_bits<Storage>::value % sizeof_bits<Element>::value),
363
+ "Storage must be divisible by Element");
364
+
365
+ private:
366
+
367
+ ///! Number of elements per storage vector
368
+ int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
369
+
370
+ ///! Bit mask
371
+ Storage const kMask =
372
+ ((sizeof_bits<Element>::value < sizeof_bits<Storage>::value) ?
373
+ (Storage(1) << sizeof_bits<Element>::value) - Storage(1) :
374
+ ~Storage(0));
375
+
376
+ private:
377
+
378
+ /// Pointer to array containing element
379
+ StoragePointer ptr_;
380
+
381
+ /// Offset (in units of elements) from pointer.
382
+ ///
383
+ /// Invariant: must always be in range [0, kElementsPerVector)
384
+ int offset_;
385
+
386
+ public:
387
+
388
+ CUTLASS_HOST_DEVICE
389
+ SubbyteReference(): ptr_(nullptr), offset_(0) { }
390
+
391
+ /// Constructor
392
+ CUTLASS_HOST_DEVICE
393
+ SubbyteReference(
394
+ Element *ptr, /// pointer to memory
395
+ int64_t offset /// logical offset in units of Element
396
+ ):
397
+ ptr_(reinterpret_cast<StoragePointer>(ptr)),
398
+ offset_(0) {
399
+
400
+ int64_t offset_in_vectors = offset / kElementsPerVector;
401
+ int64_t offset_in_elements = offset % kElementsPerVector;
402
+
403
+ ptr_ += offset_in_vectors;
404
+ offset_ = int(offset_in_elements);
405
+ }
406
+
407
+ /// Constructor
408
+ CUTLASS_HOST_DEVICE
409
+ SubbyteReference(
410
+ Element *ptr = nullptr
411
+ ): SubbyteReference(ptr, 0) { }
412
+
413
+ /// Gets storage pointer
414
+ CUTLASS_HOST_DEVICE
415
+ StoragePointer storage_pointer() const {
416
+ return ptr_;
417
+ }
418
+
419
+ /// Gets storage pointer
420
+ CUTLASS_HOST_DEVICE
421
+ Element * operator&() const {
422
+ return reinterpret_cast<Element *>(ptr_);
423
+ }
424
+
425
+ /// Gets element offset within storage vector
426
+ CUTLASS_HOST_DEVICE
427
+ int element_offset() const {
428
+ return offset_;
429
+ }
430
+
431
+ /// Unpacks an element from memory
432
+ CUTLASS_HOST_DEVICE
433
+ Element get() const {
434
+ uint8_t const* byte_ptr = reinterpret_cast<uint8_t const*>(ptr_);
435
+ // Convert offset in elements to offset in bytes
436
+ constexpr int elements_per_byte = cutlass::sizeof_bits<uint8_t>::value / cutlass::sizeof_bits<Element>::value;
437
+ byte_ptr += offset_ / elements_per_byte;
438
+ // Offset of element within a byte
439
+ int byte_offset = offset_ % elements_per_byte;
440
+ uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits<Element>::value)) & kMask);
441
+ return reinterpret_cast<Element const &>(item);
442
+ }
443
+
444
+ /// Stores an element to memory
445
+ CUTLASS_HOST_DEVICE
446
+ SubbyteReference & set(Element const &x) {
447
+
448
+ Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
449
+ Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits<Element>::value)));
450
+ Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits<Element>::value));
451
+
452
+ #if defined(__CUDA_ARCH__)
453
+
454
+ //
455
+ // Homebrew read-modify-write
456
+ //
457
+ Storage original;
458
+ Storage updated;
459
+
460
+ do {
461
+
462
+ original = (*ptr_);
463
+
464
+ updated = Storage((original & kUpdateMask) | new_bits);
465
+
466
+ original = atomicCAS(ptr_, original, updated);
467
+
468
+ } while (updated != original);
469
+
470
+ #else
471
+
472
+ Storage original = (*ptr_);
473
+ Storage updated = Storage((original & kUpdateMask) | new_bits);
474
+ *ptr_ = updated;
475
+
476
+ #endif
477
+
478
+ return *this;
479
+ }
480
+
481
+ ////
482
+
483
+ /// Unpacks an element from memory
484
+ CUTLASS_HOST_DEVICE
485
+ operator Element() const {
486
+ return get();
487
+ }
488
+
489
+ /// Stores an element to memory
490
+ CUTLASS_HOST_DEVICE
491
+ SubbyteReference &operator=(Element const & x) {
492
+ return set(x);
493
+ }
494
+
495
+ /// Stores an element to memory
496
+ CUTLASS_HOST_DEVICE
497
+ SubbyteReference &operator=(SubbyteReference const & x) {
498
+ return set(x.get());
499
+ }
500
+
501
+ /// Stores an element to memory
502
+ CUTLASS_HOST_DEVICE
503
+ SubbyteReference &operator=(
504
+ ConstSubbyteReference<Element, Storage> const &x) {
505
+ return set(x.get());
506
+ }
507
+
508
+ /// Adds an offset in units of elements to the reference
509
+ CUTLASS_HOST_DEVICE
510
+ SubbyteReference &operator+=(int offset) {
511
+
512
+ offset += offset_;
513
+
514
+ int offset_in_vectors = offset / kElementsPerVector;
515
+ int offset_in_elements = offset % kElementsPerVector;
516
+
517
+ ptr_ += offset_in_vectors;
518
+ offset_ = offset_in_elements;
519
+
520
+ return *this;
521
+ }
522
+
523
+ /// Adds an offset in units of elements to the reference
524
+ CUTLASS_HOST_DEVICE
525
+ SubbyteReference &operator+=(long long offset) {
526
+
527
+ offset += offset_;
528
+
529
+ long long offset_in_vectors = offset / kElementsPerVector;
530
+ int offset_in_elements = int(offset % kElementsPerVector);
531
+
532
+ ptr_ += offset_in_vectors;
533
+ offset_ = offset_in_elements;
534
+
535
+ return *this;
536
+ }
537
+
538
+ /// Adds an offset in units of elements to the reference
539
+ CUTLASS_HOST_DEVICE
540
+ SubbyteReference &operator-=(int offset) {
541
+
542
+ int offset_in_vectors = offset / kElementsPerVector;
543
+ int offset_in_elements = offset % kElementsPerVector;
544
+
545
+ ptr_ -= offset_in_vectors;
546
+ offset_ -= offset_in_elements;
547
+
548
+ if (offset_ < 0) {
549
+ offset_ += kElementsPerVector;
550
+ --ptr_;
551
+ }
552
+
553
+ return *this;
554
+ }
555
+
556
+ /// Adds an offset in units of elements to the reference
557
+ CUTLASS_HOST_DEVICE
558
+ SubbyteReference &operator-=(long long offset) {
559
+
560
+ long long offset_in_vectors = offset / kElementsPerVector;
561
+ int offset_in_elements = int(offset % kElementsPerVector);
562
+
563
+ ptr_ -= offset_in_vectors;
564
+ offset_ -= offset_in_elements;
565
+
566
+ if (offset_ < 0) {
567
+ offset_ += kElementsPerVector;
568
+ --ptr_;
569
+ }
570
+
571
+ return *this;
572
+ }
573
+
574
+ /// Returns a reference to an element with a given offset from the current reference
575
+ CUTLASS_HOST_DEVICE
576
+ SubbyteReference operator+(int offset) const {
577
+
578
+ SubbyteReference ref(ptr_, offset_);
579
+ ref += offset;
580
+
581
+ return ref;
582
+ }
583
+
584
+ /// Returns a reference to an element with a given offset from the current reference
585
+ CUTLASS_HOST_DEVICE
586
+ SubbyteReference operator+(long long offset) const {
587
+
588
+ SubbyteReference ref(ptr_, offset_);
589
+ ref += offset;
590
+
591
+ return ref;
592
+ }
593
+
594
+ /// Returns a reference to an element with a given offset from the current reference
595
+ CUTLASS_HOST_DEVICE
596
+ SubbyteReference operator-(int offset) const {
597
+
598
+ SubbyteReference ref(ptr_, offset_);
599
+ ref -= offset;
600
+
601
+ return ref;
602
+ }
603
+
604
+ /// Returns a reference to an element with a given offset from the current reference
605
+ CUTLASS_HOST_DEVICE
606
+ SubbyteReference operator-=(long long offset) const {
607
+
608
+ SubbyteReference ref(ptr_, offset_);
609
+ ref -= offset;
610
+
611
+ return ref;
612
+ }
613
+
614
+ /// Computes the difference in elements between references
615
+ CUTLASS_HOST_DEVICE
616
+ ptrdiff_t operator-(SubbyteReference ref) const {
617
+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
618
+ }
619
+
620
+ /// Explicit cast to int
621
+ CUTLASS_HOST_DEVICE
622
+ explicit operator int() const {
623
+ return int(get());
624
+ }
625
+
626
+ /// Explicit cast to signed 64-bit integer
627
+ CUTLASS_HOST_DEVICE
628
+ explicit operator int64_t() const {
629
+ return int64_t(get());
630
+ }
631
+
632
+ /// Explicit cast to unsigned 64-bit integer
633
+ CUTLASS_HOST_DEVICE
634
+ explicit operator uint64_t() const {
635
+ return uint64_t(get());
636
+ }
637
+
638
+ /// Explicit cast to float
639
+ CUTLASS_HOST_DEVICE
640
+ explicit operator float() const {
641
+ return float(get());
642
+ }
643
+
644
+ /// Explicit cast to double
645
+ CUTLASS_HOST_DEVICE
646
+ explicit operator double() const {
647
+ return double(get());
648
+ }
649
+ };
650
+
651
+ /////////////////////////////////////////////////////////////////////////////////////////////////
652
+
653
+ template<typename T> using _war = T;
654
+ template <
655
+ typename Element_, /// CUTLASS numeric element type.
656
+ typename Storage_ /// Underlying basic storage type.
657
+ >
658
+ class SubbyteReference<Element_, Storage_,
659
+ typename platform::enable_if<sizeof_bits<Storage_>::value % sizeof_bits<Element_>::value != 0>::type> {
660
+ public:
661
+
662
+ using Element = Element_;
663
+ /// Note: It's possible that StorageUnit is not divisible by Element.
664
+ /// For example, an Element instance might be stored across 2 StorageUnit instances.
665
+ /// Thus, CUTLASS needs a storage vector to hold an integer number of Element instances.
666
+
667
+ using StorageUnit = Storage_;
668
+ private:
669
+ using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
670
+ public:
671
+ static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits;
672
+ static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit;
673
+
674
+ using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec];
675
+ using StorageVecPointer = StorageVec *;
676
+
677
+ using CudaAtomicType = typename platform::conditional<
678
+ sizeof_bits<StorageUnit>::value == 16,
679
+ uint32_t,
680
+ uint64_t
681
+ >::type;
682
+
683
+ static_assert(sizeof_bits<Element>::value <= sizeof_bits<StorageVec>::value,
684
+ "Size of Element must not be greater than StorageVec.");
685
+
686
+ static_assert(!(sizeof_bits<StorageVec>::value % sizeof_bits<Element>::value),
687
+ "StorageVec must be divisible by Element");
688
+
689
+ private:
690
+
691
+ ///! Number of elements per storage vector
692
+ int const kElementsPerVector = sizeof_bits<StorageVec>::value / sizeof_bits<Element>::value;
693
+
694
+ ///! Bit mask for storage unit.
695
+ StorageUnit const kMask = (StorageUnit(1) << sizeof_bits<Element>::value) - StorageUnit(1);
696
+
697
+ /// Pointer to array containing element
698
+ _war<StorageVecPointer> ptr_;
699
+
700
+ /// Offset (in units of elements) from pointer.
701
+ ///
702
+ /// Invariant: must always be in range [0, kElementsPerVector)
703
+ int offset_;
704
+
705
+ /// Element may be stored across 2 storage unit.
706
+ /// Low storage unit index in StorageVec
707
+ /// High storage unit index in StorageVec
708
+ int low_storage_unit_idx_;
709
+ int high_storage_unit_idx_;
710
+
711
+ /// Full Mask to extract the entire element
712
+ uint64_t full_element_mask_;
713
+
714
+ /// Mask to extract the Element from Low storage unit and High storage unit.
715
+ StorageUnit low_storage_mask_;
716
+ StorageUnit high_storage_mask_;
717
+
718
+ /// Start bit index inside the storage unit.
719
+ int start_bit_idx_;
720
+
721
+ private:
722
+
723
+ CUTLASS_HOST_DEVICE
724
+ void update_element_status() {
725
+ int num_bits = offset_ * sizeof_bits<Element>::value;
726
+
727
+ start_bit_idx_ = num_bits % sizeof_bits<StorageUnit>::value;
728
+
729
+ low_storage_unit_idx_ = num_bits / sizeof_bits<StorageUnit>::value;
730
+ high_storage_unit_idx_ = sizeof_bits<StorageUnit>::value - (start_bit_idx_) < sizeof_bits<Element>::value
731
+ ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_;
732
+
733
+ full_element_mask_ = uint64_t(kMask) << start_bit_idx_;
734
+ low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0));
735
+ high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits<StorageUnit>::value) & ~StorageUnit(0));
736
+ }
737
+
738
+ public:
739
+
740
+ CUTLASS_HOST_DEVICE
741
+ SubbyteReference(): ptr_(nullptr), offset_(0) { }
742
+
743
+ /// Constructor
744
+ CUTLASS_HOST_DEVICE
745
+ SubbyteReference(
746
+ Element *ptr, /// pointer to memory
747
+ int64_t offset /// logical offset in units of Element
748
+ ):
749
+ ptr_(reinterpret_cast<StorageVecPointer>(ptr)),
750
+ offset_(0) {
751
+ int64_t offset_in_vectors = offset / kElementsPerVector;
752
+ int64_t offset_in_elements = offset % kElementsPerVector;
753
+
754
+ ptr_ += offset_in_vectors;
755
+ offset_ = int(offset_in_elements);
756
+
757
+ update_element_status();
758
+ }
759
+
760
+ /// Constructor
761
+ CUTLASS_HOST_DEVICE
762
+ SubbyteReference(
763
+ Element *ptr = nullptr
764
+ ): SubbyteReference(ptr, 0) { }
765
+
766
+ /// Gets StorageVec pointer
767
+ CUTLASS_HOST_DEVICE
768
+ StorageVecPointer storage_pointer() const {
769
+ return ptr_;
770
+ }
771
+
772
+ /// Gets StorageVec pointer
773
+ CUTLASS_HOST_DEVICE
774
+ Element * operator&() const {
775
+ return reinterpret_cast<Element *>(ptr_);
776
+ }
777
+
778
+ /// Gets element offset within StorageVec vector
779
+ CUTLASS_HOST_DEVICE
780
+ int element_offset() const {
781
+ return offset_;
782
+ }
783
+
784
+ /// Unpacks an element from memory
785
+ CUTLASS_HOST_DEVICE
786
+ Element get() const {
787
+ StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_;
788
+ StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0;
789
+
790
+ uint64_t full_item = ((uint64_t)high_bits << sizeof_bits<StorageUnit>::value) | low_bits;
791
+ uint8_t result = uint8_t(full_item >> start_bit_idx_);
792
+
793
+ return reinterpret_cast<Element const &>(result);
794
+ }
795
+
796
+ /// Stores an element to memory
797
+ CUTLASS_HOST_DEVICE
798
+ SubbyteReference & set(Element const &x) {
799
+
800
+ uint64_t item = static_cast<uint64_t>((reinterpret_cast<uint8_t const &>(x) & kMask)) << start_bit_idx_;
801
+
802
+ StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0));
803
+ StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits<StorageUnit>::value);
804
+
805
+ StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0)));
806
+ StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits<StorageUnit>::value) & (~StorageUnit(0)));
807
+
808
+ #if defined(__CUDA_ARCH__)
809
+ //
810
+ // Homebrew read-modify-write
811
+ //
812
+ if(high_storage_unit_idx_ != low_storage_unit_idx_){
813
+ /// Only need update 2 storage unit at once.
814
+ /// consider misaligned address issue, we need to do atomicCAS twice
815
+ StorageUnit original_low_bits, original_high_bits, update_low_bits, update_high_bits;
816
+ do {
817
+ original_low_bits = ((*ptr_)[low_storage_unit_idx_]);
818
+ update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits;
819
+ original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits);
820
+ } while (update_low_bits != original_low_bits);
821
+ do {
822
+ original_high_bits = ((*ptr_)[high_storage_unit_idx_]);
823
+ update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits;
824
+ original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits);
825
+ } while (update_high_bits != original_high_bits);
826
+ }
827
+ else {
828
+ /// Only need update 1 storage unit.
829
+ StorageUnit original, updated;
830
+ do {
831
+ original = ((*ptr_)[low_storage_unit_idx_]);
832
+
833
+ updated = (original & kLowUpdateMask) | low_new_bits;
834
+
835
+ original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated);
836
+
837
+ } while (updated != original);
838
+ }
839
+ #else
840
+
841
+
842
+ StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits;
843
+ StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits;
844
+
845
+ (*ptr_)[low_storage_unit_idx_] = update_low_bits;
846
+
847
+ if(low_storage_unit_idx_ != high_storage_unit_idx_)
848
+ (*ptr_)[high_storage_unit_idx_] = update_high_bits;
849
+ #endif
850
+
851
+ return *this;
852
+ }
853
+
854
+ ////
855
+
856
+ /// Unpacks an element from memory
857
+ CUTLASS_HOST_DEVICE
858
+ operator Element() const {
859
+ return get();
860
+ }
861
+
862
+ /// Stores an element to memory
863
+ CUTLASS_HOST_DEVICE
864
+ SubbyteReference &operator=(Element const & x) {
865
+ return set(x);
866
+ }
867
+
868
+ /// Stores an element to memory
869
+ CUTLASS_HOST_DEVICE
870
+ SubbyteReference &operator=(SubbyteReference const & x) {
871
+ return set(x.get());
872
+ }
873
+
874
+ /// Stores an element to memory
875
+ CUTLASS_HOST_DEVICE
876
+ SubbyteReference &operator=(
877
+ ConstSubbyteReference<Element, StorageVec> const &x) {
878
+ return set(x.get());
879
+ }
880
+
881
+ /// Adds an offset in units of elements to the reference
882
+ CUTLASS_HOST_DEVICE
883
+ SubbyteReference &operator+=(int offset) {
884
+
885
+ offset += offset_;
886
+
887
+ int offset_in_vectors = offset / kElementsPerVector;
888
+ int offset_in_elements = offset % kElementsPerVector;
889
+
890
+ ptr_ += offset_in_vectors;
891
+ offset_ = offset_in_elements;
892
+
893
+ update_element_status();
894
+
895
+ return *this;
896
+ }
897
+
898
+ /// Adds an offset in units of elements to the reference
899
+ CUTLASS_HOST_DEVICE
900
+ SubbyteReference &operator+=(long long offset) {
901
+
902
+ offset += offset_;
903
+
904
+ long long offset_in_vectors = offset / kElementsPerVector;
905
+ int offset_in_elements = int(offset % kElementsPerVector);
906
+
907
+ ptr_ += offset_in_vectors;
908
+ offset_ = offset_in_elements;
909
+
910
+ update_element_status();
911
+
912
+ return *this;
913
+ }
914
+
915
+ /// Adds an offset in units of elements to the reference
916
+ CUTLASS_HOST_DEVICE
917
+ SubbyteReference &operator-=(int offset) {
918
+
919
+ int offset_in_vectors = offset / kElementsPerVector;
920
+ int offset_in_elements = offset % kElementsPerVector;
921
+
922
+ ptr_ -= offset_in_vectors;
923
+ offset_ -= offset_in_elements;
924
+
925
+ if (offset_ < 0) {
926
+ offset_ += kElementsPerVector;
927
+ --ptr_;
928
+ }
929
+
930
+ update_element_status();
931
+ return *this;
932
+ }
933
+
934
+ /// Adds an offset in units of elements to the reference
935
+ CUTLASS_HOST_DEVICE
936
+ SubbyteReference &operator-=(long long offset) {
937
+
938
+ long long offset_in_vectors = offset / kElementsPerVector;
939
+ int offset_in_elements = int(offset % kElementsPerVector);
940
+
941
+ ptr_ -= offset_in_vectors;
942
+ offset_ -= offset_in_elements;
943
+
944
+ if (offset_ < 0) {
945
+ offset_ += kElementsPerVector;
946
+ --ptr_;
947
+ }
948
+
949
+ update_element_status();
950
+ return *this;
951
+ }
952
+
953
+ /// Returns a reference to an element with a given offset from the current reference
954
+ CUTLASS_HOST_DEVICE
955
+ SubbyteReference operator+(int offset) const {
956
+
957
+ SubbyteReference ref(ptr_, offset_);
958
+ ref += offset;
959
+
960
+ return ref;
961
+ }
962
+
963
+ /// Returns a reference to an element with a given offset from the current reference
964
+ CUTLASS_HOST_DEVICE
965
+ SubbyteReference operator+(long long offset) const {
966
+
967
+ SubbyteReference ref(ptr_, offset_);
968
+ ref += offset;
969
+
970
+ return ref;
971
+ }
972
+
973
+ /// Returns a reference to an element with a given offset from the current reference
974
+ CUTLASS_HOST_DEVICE
975
+ SubbyteReference operator-(int offset) const {
976
+
977
+ SubbyteReference ref(ptr_, offset_);
978
+ ref -= offset;
979
+
980
+ return ref;
981
+ }
982
+
983
+ /// Returns a reference to an element with a given offset from the current reference
984
+ CUTLASS_HOST_DEVICE
985
+ SubbyteReference operator-=(long long offset) const {
986
+
987
+ SubbyteReference ref(ptr_, offset_);
988
+ ref -= offset;
989
+
990
+ return ref;
991
+ }
992
+
993
+ /// Computes the difference in elements between references
994
+ CUTLASS_HOST_DEVICE
995
+ ptrdiff_t operator-(SubbyteReference ref) const {
996
+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
997
+ }
998
+
999
+ /// Explicit cast to int
1000
+ CUTLASS_HOST_DEVICE
1001
+ explicit operator int() const {
1002
+ return int(get());
1003
+ }
1004
+
1005
+ /// Explicit cast to signed 64-bit integer
1006
+ CUTLASS_HOST_DEVICE
1007
+ explicit operator int64_t() const {
1008
+ return int64_t(get());
1009
+ }
1010
+
1011
+ /// Explicit cast to unsigned 64-bit integer
1012
+ CUTLASS_HOST_DEVICE
1013
+ explicit operator uint64_t() const {
1014
+ return uint64_t(get());
1015
+ }
1016
+
1017
+ /// Explicit cast to float
1018
+ CUTLASS_HOST_DEVICE
1019
+ explicit operator float() const {
1020
+ return float(get());
1021
+ }
1022
+
1023
+ /// Explicit cast to double
1024
+ CUTLASS_HOST_DEVICE
1025
+ explicit operator double() const {
1026
+ return double(get());
1027
+ }
1028
+ };
1029
+
1030
+ template<typename T> using _war = T;
1031
+ template <
1032
+ typename Element_, /// CUTLASS numeric element type.
1033
+ typename Storage_ /// Underlying storage type. Must be able to hold an integer
1034
+ >
1035
+ class ConstSubbyteReference<Element_, Storage_,
1036
+ typename platform::enable_if<sizeof_bits<Storage_>::value % sizeof_bits<Element_>::value != 0>::type> {
1037
+ public:
1038
+
1039
+ using Element = Element_;
1040
+ ///! Note: Storage unit could not be divisibale by Element,
1041
+ /// Type element may be stored across 2 storage units, so need a storage vector to hold integer
1042
+ /// number of objects of type Element.
1043
+ using StorageUnit = Storage_;
1044
+ static int const kBitsStoredVec = cutlass::lcm_cxx11(sizeof_bits<Element>::value, sizeof_bits<StorageUnit>::value);
1045
+ static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits<StorageUnit>::value;
1046
+
1047
+ using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec];
1048
+ using StorageVecPointer = StorageVec const *;
1049
+
1050
+ using CudaAtomicType = typename platform::conditional<
1051
+ sizeof_bits<StorageUnit>::value == 16,
1052
+ uint32_t,
1053
+ uint64_t
1054
+ >::type;
1055
+
1056
+ static_assert(sizeof_bits<Element>::value <= sizeof_bits<StorageVec>::value,
1057
+ "Size of Element must not be greater than StorageVec.");
1058
+
1059
+ static_assert(!(sizeof_bits<StorageVec>::value % sizeof_bits<Element>::value),
1060
+ "StorageVec must be divisible by Element");
1061
+
1062
+ private:
1063
+
1064
+ ///! Number of elements per storage vector
1065
+ int const kElementsPerVector = sizeof_bits<StorageVec>::value / sizeof_bits<Element>::value;
1066
+
1067
+ ///! Bit mask for storage unit.
1068
+ StorageUnit const kMask = (StorageUnit(1) << sizeof_bits<Element>::value) - StorageUnit(1);
1069
+
1070
+ /// Pointer to array containing element
1071
+ _war<StorageVecPointer> ptr_;
1072
+
1073
+ /// Offset (in units of elements) from pointer.
1074
+ ///
1075
+ /// Invariant: must always be in range [0, kElementsPerVector)
1076
+ int offset_;
1077
+
1078
+ /// Element may be stored across 2 storage unit.
1079
+ /// Low storage unit index in StorageVec
1080
+ /// High storage unit index in StorageVec
1081
+ int low_storage_unit_idx_;
1082
+ int high_storage_unit_idx_;
1083
+
1084
+ /// Full Mask to extract the entire element
1085
+ uint64_t full_element_mask_;
1086
+
1087
+ /// Mask to extract the Element from Low storage unit and High storage unit.
1088
+ StorageUnit low_storage_mask_;
1089
+ StorageUnit high_storage_mask_;
1090
+
1091
+ /// Start bit index inside the storage unit.
1092
+ int start_bit_idx_;
1093
+
1094
+ private:
1095
+
1096
+ CUTLASS_HOST_DEVICE
1097
+ void update_element_status() {
1098
+ int num_bits = offset_ * sizeof_bits<Element>::value;
1099
+
1100
+ start_bit_idx_ = num_bits % sizeof_bits<StorageUnit>::value;
1101
+
1102
+ low_storage_unit_idx_ = num_bits / sizeof_bits<StorageUnit>::value;
1103
+ high_storage_unit_idx_ = sizeof_bits<StorageUnit>::value - (start_bit_idx_) < sizeof_bits<Element>::value
1104
+ ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_;
1105
+
1106
+ full_element_mask_ = uint64_t(kMask) << start_bit_idx_;
1107
+ low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0));
1108
+ high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits<StorageUnit>::value) & ~StorageUnit(0));
1109
+ }
1110
+
1111
+ public:
1112
+
1113
+ CUTLASS_HOST_DEVICE
1114
+ ConstSubbyteReference(): ptr_(nullptr), offset_(0) { }
1115
+
1116
+ /// Constructor
1117
+ CUTLASS_HOST_DEVICE
1118
+ ConstSubbyteReference(
1119
+ Element const *ptr, /// pointer to memory
1120
+ int64_t offset /// logical offset in units of Element
1121
+ ):
1122
+ ptr_(reinterpret_cast<StorageVecPointer>(ptr)),
1123
+ offset_(0) {
1124
+
1125
+ int64_t offset_in_vectors = offset / kElementsPerVector;
1126
+ int64_t offset_in_elements = offset % kElementsPerVector;
1127
+
1128
+ ptr_ += offset_in_vectors;
1129
+ offset_ = int(offset_in_elements);
1130
+
1131
+ update_element_status();
1132
+ }
1133
+
1134
+ /// Constructor
1135
+ CUTLASS_HOST_DEVICE
1136
+ ConstSubbyteReference(
1137
+ Element *ptr = nullptr
1138
+ ): ConstSubbyteReference(ptr, 0) { }
1139
+
1140
+ /// Gets storage pointer
1141
+ CUTLASS_HOST_DEVICE
1142
+ StorageVecPointer storage_pointer() const {
1143
+ return ptr_;
1144
+ }
1145
+
1146
+ /// Gets element offset within storage vector
1147
+ CUTLASS_HOST_DEVICE
1148
+ int element_offset() const {
1149
+ return offset_;
1150
+ }
1151
+
1152
+ /// Unpacks an element from memory
1153
+ CUTLASS_HOST_DEVICE
1154
+ Element get() const {
1155
+ StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_;
1156
+ StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0;
1157
+
1158
+ uint64_t full_item = ((uint64_t)high_bits << sizeof_bits<StorageUnit>::value) | low_bits;
1159
+ uint8_t result = uint8_t(full_item >> start_bit_idx_);
1160
+
1161
+ return reinterpret_cast<Element const &>(result);
1162
+ }
1163
+
1164
+ /// Unpacks an element from memory
1165
+ CUTLASS_HOST_DEVICE
1166
+ operator Element() const {
1167
+ return get();
1168
+ }
1169
+
1170
+ /// Adds an offset in units of elements to the reference
1171
+ CUTLASS_HOST_DEVICE
1172
+ ConstSubbyteReference &operator+=(int offset) {
1173
+
1174
+ offset += offset_;
1175
+
1176
+ int offset_in_vectors = offset / kElementsPerVector;
1177
+ int offset_in_elements = offset % kElementsPerVector;
1178
+
1179
+ ptr_ += offset_in_vectors;
1180
+ offset_ = offset_in_elements;
1181
+
1182
+ update_element_status();
1183
+
1184
+ return *this;
1185
+ }
1186
+
1187
+ /// Adds an offset in units of elements to the reference
1188
+ CUTLASS_HOST_DEVICE
1189
+ ConstSubbyteReference &operator+=(long long offset) {
1190
+
1191
+ offset += offset_;
1192
+
1193
+ long long offset_in_vectors = offset / kElementsPerVector;
1194
+ int offset_in_elements = int(offset % kElementsPerVector);
1195
+
1196
+ ptr_ += offset_in_vectors;
1197
+ offset_ = offset_in_elements;
1198
+
1199
+ update_element_status();
1200
+
1201
+ return *this;
1202
+ }
1203
+
1204
+ /// Adds an offset in units of elements to the reference
1205
+ CUTLASS_HOST_DEVICE
1206
+ ConstSubbyteReference &operator-=(int offset) {
1207
+
1208
+ int offset_in_vectors = offset / kElementsPerVector;
1209
+ int offset_in_elements = offset % kElementsPerVector;
1210
+
1211
+ ptr_ -= offset_in_vectors;
1212
+ offset_ -= offset_in_elements;
1213
+
1214
+ if (offset_ < 0) {
1215
+ offset_ += kElementsPerVector;
1216
+ --ptr_;
1217
+ }
1218
+
1219
+ update_element_status();
1220
+
1221
+ return *this;
1222
+ }
1223
+
1224
+ /// Adds an offset in units of elements to the reference
1225
+ CUTLASS_HOST_DEVICE
1226
+ ConstSubbyteReference &operator-=(long long offset) {
1227
+
1228
+ long long offset_in_vectors = offset / kElementsPerVector;
1229
+ int offset_in_elements = int(offset % kElementsPerVector);
1230
+
1231
+ ptr_ -= offset_in_vectors;
1232
+ offset_ -= offset_in_elements;
1233
+
1234
+ if (offset_ < 0) {
1235
+ offset_ += kElementsPerVector;
1236
+ --ptr_;
1237
+ }
1238
+
1239
+ update_element_status();
1240
+
1241
+ return *this;
1242
+ }
1243
+
1244
+ /// Returns a reference to an element with a given offset from the current reference
1245
+ CUTLASS_HOST_DEVICE
1246
+ ConstSubbyteReference operator+(int offset) const {
1247
+
1248
+ ConstSubbyteReference ref(ptr_, offset_);
1249
+ ref += offset;
1250
+
1251
+ return ref;
1252
+ }
1253
+
1254
+ /// Returns a reference to an element with a given offset from the current reference
1255
+ CUTLASS_HOST_DEVICE
1256
+ ConstSubbyteReference operator+(long long offset) const {
1257
+
1258
+ ConstSubbyteReference ref(ptr_, offset_);
1259
+ ref += offset;
1260
+
1261
+ return ref;
1262
+ }
1263
+
1264
+ /// Returns a reference to an element with a given offset from the current reference
1265
+ CUTLASS_HOST_DEVICE
1266
+ ConstSubbyteReference operator-(int offset) const {
1267
+
1268
+ ConstSubbyteReference ref(ptr_, offset_);
1269
+ ref -= offset;
1270
+
1271
+ return ref;
1272
+ }
1273
+
1274
+ /// Returns a reference to an element with a given offset from the current reference
1275
+ CUTLASS_HOST_DEVICE
1276
+ ConstSubbyteReference operator-=(long long offset) const {
1277
+
1278
+ ConstSubbyteReference ref(ptr_, offset_);
1279
+ ref -= offset;
1280
+
1281
+ return ref;
1282
+ }
1283
+
1284
+ /// Computes the difference in elements between references
1285
+ CUTLASS_HOST_DEVICE
1286
+ ptrdiff_t operator-(ConstSubbyteReference ref) const {
1287
+ return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
1288
+ }
1289
+
1290
+ /// Explicit cast to int
1291
+ CUTLASS_HOST_DEVICE
1292
+ explicit operator int() const {
1293
+ return int(get());
1294
+ }
1295
+
1296
+ /// Explicit cast to signed 64-bit integer
1297
+ CUTLASS_HOST_DEVICE
1298
+ explicit operator int64_t() const {
1299
+ return int64_t(get());
1300
+ }
1301
+
1302
+ /// Explicit cast to unsigned 64-bit integer
1303
+ CUTLASS_HOST_DEVICE
1304
+ explicit operator uint64_t() const {
1305
+ return uint64_t(get());
1306
+ }
1307
+
1308
+ /// Explicit cast to float
1309
+ CUTLASS_HOST_DEVICE
1310
+ explicit operator float() const {
1311
+ return float(get());
1312
+ }
1313
+
1314
+ /// Explicit cast to double
1315
+ CUTLASS_HOST_DEVICE
1316
+ explicit operator double() const {
1317
+ return double(get());
1318
+ }
1319
+ };
1320
+
1321
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1322
+
1323
+ template <typename Element, bool subbyte = (sizeof_bits<Element>::value < 8)>
1324
+ struct ReferenceFactory;
1325
+
1326
+ template <typename Element>
1327
+ struct ReferenceFactory<Element, false> {
1328
+
1329
+ ///! Number of elements per storage vector
1330
+ static int const kElementsPerVector = 1;
1331
+
1332
+ CUTLASS_HOST_DEVICE
1333
+ static Element &get(Element *ptr, int64_t offset) {
1334
+ return ptr[offset];
1335
+ }
1336
+
1337
+ CUTLASS_HOST_DEVICE
1338
+ static Element const &get(Element const *ptr, int64_t offset) {
1339
+ return ptr[offset];
1340
+ }
1341
+
1342
+ CUTLASS_HOST_DEVICE
1343
+ static Element *add_pointer_offset(Element *ptr, int64_t offset) {
1344
+ return ptr + offset;
1345
+ }
1346
+
1347
+ CUTLASS_HOST_DEVICE
1348
+ static Element const *add_pointer_offset(Element const *ptr, int64_t offset) {
1349
+ return ptr + offset;
1350
+ }
1351
+ };
1352
+
1353
+ template <typename Element>
1354
+ struct ReferenceFactory<Element, true> {
1355
+
1356
+ //
1357
+ // Static methods
1358
+ //
1359
+
1360
+ CUTLASS_HOST_DEVICE
1361
+ static SubbyteReference<Element> get(Element *ptr, int64_t offset) {
1362
+ return SubbyteReference<Element>(ptr, offset);
1363
+ }
1364
+
1365
+ CUTLASS_HOST_DEVICE
1366
+ static ConstSubbyteReference<Element> get(Element const *ptr,
1367
+ int64_t offset) {
1368
+ return ConstSubbyteReference<Element>(ptr, offset);
1369
+ }
1370
+
1371
+ /// Helper to add an offset in number of elements, assuming this offset is divisible
1372
+ /// by the vector size.
1373
+ CUTLASS_HOST_DEVICE
1374
+ static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) {
1375
+ return &SubbyteReference<Element>(ptr, offset_in_elements);
1376
+ }
1377
+
1378
+ /// Helper to add an offset in number of elements, assuming this offset is divisible
1379
+ /// by the vector size.
1380
+ CUTLASS_HOST_DEVICE
1381
+ static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) {
1382
+ return &ConstSubbyteReference<Element>(ptr, offset_in_elements);
1383
+ }
1384
+ };
1385
+
1386
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1387
+
1388
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a canonical coordinate for rank=4 tensors offering named indices.
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "cutlass/coord.h"
38
+
39
+ namespace cutlass {
40
+
41
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ /// Defines a canonical 4D coordinate used by tensor operations.
44
+ struct Tensor4DCoord : public Coord<4> {
45
+
46
+ /// Base class
47
+ using Base = Coord<4>;
48
+
49
+ /// Index type
50
+ using Index = typename Base::Index;
51
+
52
+ /// LongIndex type
53
+ using LongIndex = typename Base::LongIndex;
54
+
55
+ /// Batch dimension
56
+ static int const kN = 0;
57
+
58
+ /// Height dimension
59
+ static int const kH = 1;
60
+
61
+ /// Width dimension
62
+ static int const kW = 2;
63
+
64
+ /// Channels dimension
65
+ static int const kC = 3;
66
+
67
+ //
68
+ // Methods
69
+ //
70
+
71
+ /// Default ctor
72
+ CUTLASS_HOST_DEVICE
73
+ Tensor4DCoord() { }
74
+
75
+ /// Constructs from Coord<4>
76
+ CUTLASS_HOST_DEVICE
77
+ Tensor4DCoord(Coord<4> const &coord): Base(coord) { }
78
+
79
+ /// Helper to construct from N, H, W, and C.
80
+ CUTLASS_HOST_DEVICE
81
+ Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { }
82
+
83
+ /// Helper to construct from N, H, W, and C, which are LongIndex type
84
+ CUTLASS_HOST_DEVICE
85
+ Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c)
86
+ : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { }
87
+
88
+ /// Returns the batch of the coordinate
89
+ CUTLASS_HOST_DEVICE
90
+ Index const & n() const { return this->at(kN); }
91
+
92
+ /// Returns the batch of the coordinate
93
+ CUTLASS_HOST_DEVICE
94
+ Index & n() { return this->at(kN); }
95
+
96
+ /// Returns the row of the coordinate
97
+ CUTLASS_HOST_DEVICE
98
+ Index const & h() const { return this->at(kH); }
99
+
100
+ /// Returns the row of the coordinate
101
+ CUTLASS_HOST_DEVICE
102
+ Index & h() { return this->at(kH); }
103
+
104
+ /// Returns the column of the coordinate
105
+ CUTLASS_HOST_DEVICE
106
+ Index const & w() const { return this->at(kW); }
107
+
108
+ /// Returns the column of the coordinate
109
+ CUTLASS_HOST_DEVICE
110
+ Index & w() { return this->at(kW); }
111
+
112
+ /// Returns the channel of the coordinate
113
+ CUTLASS_HOST_DEVICE
114
+ Index const & c() const { return this->at(kC); }
115
+
116
+ /// Returns the channel of the coordinate
117
+ CUTLASS_HOST_DEVICE
118
+ Index & c() { return this->at(kC); }
119
+
120
+ //
121
+ // Coord operators
122
+ //
123
+
124
+ /// Element-wise addition
125
+ CUTLASS_HOST_DEVICE
126
+ Tensor4DCoord operator+(Base const& b) const {
127
+ return Tensor4DCoord(Base::operator+(b));
128
+ }
129
+
130
+ /// Element-wise subtraction
131
+ CUTLASS_HOST_DEVICE
132
+ Tensor4DCoord operator-(Base const& b) const {
133
+ return Tensor4DCoord(Base::operator-(b));
134
+ }
135
+
136
+ /// Element-wise multiplication
137
+ CUTLASS_HOST_DEVICE
138
+ Tensor4DCoord operator*(Base const& b) const {
139
+ return Tensor4DCoord(Base::operator*(b));
140
+ }
141
+
142
+ /// Element-wise division
143
+ CUTLASS_HOST_DEVICE
144
+ Tensor4DCoord operator/(Base const& b) const {
145
+ return Tensor4DCoord(Base::operator/(b));
146
+ }
147
+
148
+ /// In-place addition
149
+ CUTLASS_HOST_DEVICE
150
+ Tensor4DCoord& operator+=(Base const& b) {
151
+ Base::operator+=(b);
152
+ return *this;
153
+ }
154
+
155
+ /// In-place subtraction
156
+ CUTLASS_HOST_DEVICE
157
+ Tensor4DCoord& operator-=(Base const& b) {
158
+ Base::operator-=(b);
159
+ return *this;
160
+ }
161
+
162
+ /// In-place multiplication
163
+ CUTLASS_HOST_DEVICE
164
+ Tensor4DCoord& operator*=(Base const& b) {
165
+ Base::operator*=(b);
166
+ return *this;
167
+ }
168
+
169
+ /// In-place division
170
+ CUTLASS_HOST_DEVICE
171
+ Tensor4DCoord& operator/=(Base const& b) {
172
+ Base::operator/=(b);
173
+ return *this;
174
+ }
175
+ };
176
+
177
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
178
+
179
+ /// Defines a canonical 5D coordinate used by tensor operations.
180
+ struct Tensor5DCoord : public Coord<5> {
181
+
182
+ /// Base class
183
+ using Base = Coord<5>;
184
+
185
+ /// Index type
186
+ using Index = typename Base::Index;
187
+
188
+ /// LongIndex type
189
+ using LongIndex = typename Base::LongIndex;
190
+
191
+ /// Batch dimension
192
+ static int const kN = 0;
193
+
194
+ /// Depth dimension
195
+ static int const kD = 1;
196
+
197
+ /// Height dimension
198
+ static int const kH = 2;
199
+
200
+ /// Width dimension
201
+ static int const kW = 3;
202
+
203
+ /// Channels dimension
204
+ static int const kC = 4;
205
+
206
+ //
207
+ // Methods
208
+ //
209
+
210
+ /// Default ctor
211
+ CUTLASS_HOST_DEVICE
212
+ Tensor5DCoord() { }
213
+
214
+ /// Constructs from Coord<5>
215
+ CUTLASS_HOST_DEVICE
216
+ Tensor5DCoord(Coord<5> const &coord): Base(coord) { }
217
+
218
+ /// Helper to construct from N, D, H, W, and C.
219
+ CUTLASS_HOST_DEVICE
220
+ Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { }
221
+
222
+ /// Helper to construct from N, D, H, W, and C, which are LongIndex type
223
+ CUTLASS_HOST_DEVICE
224
+ Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c)
225
+ : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { }
226
+
227
+ /// Returns the batch of the coordinate
228
+ CUTLASS_HOST_DEVICE
229
+ Index const & n() const { return this->at(kN); }
230
+
231
+ /// Returns the batch of the coordinate
232
+ CUTLASS_HOST_DEVICE
233
+ Index & n() { return this->at(kN); }
234
+
235
+ /// Returns the batch of the coordinate
236
+ CUTLASS_HOST_DEVICE
237
+ Index const & d() const { return this->at(kD); }
238
+
239
+ /// Returns the batch of the coordinate
240
+ CUTLASS_HOST_DEVICE
241
+ Index & d() { return this->at(kD); }
242
+
243
+ /// Returns the row of the coordinate
244
+ CUTLASS_HOST_DEVICE
245
+ Index const & h() const { return this->at(kH); }
246
+
247
+ /// Returns the row of the coordinate
248
+ CUTLASS_HOST_DEVICE
249
+ Index & h() { return this->at(kH); }
250
+
251
+ /// Returns the column of the coordinate
252
+ CUTLASS_HOST_DEVICE
253
+ Index const & w() const { return this->at(kW); }
254
+
255
+ /// Returns the column of the coordinate
256
+ CUTLASS_HOST_DEVICE
257
+ Index & w() { return this->at(kW); }
258
+
259
+ /// Returns the channel of the coordinate
260
+ CUTLASS_HOST_DEVICE
261
+ Index const & c() const { return this->at(kC); }
262
+
263
+ /// Returns the channel of the coordinate
264
+ CUTLASS_HOST_DEVICE
265
+ Index & c() { return this->at(kC); }
266
+
267
+ //
268
+ // Coord operators
269
+ //
270
+
271
+ /// Element-wise addition
272
+ CUTLASS_HOST_DEVICE
273
+ Tensor5DCoord operator+(Base const& b) const {
274
+ return Tensor5DCoord(Base::operator+(b));
275
+ }
276
+
277
+ /// Element-wise subtraction
278
+ CUTLASS_HOST_DEVICE
279
+ Tensor5DCoord operator-(Base const& b) const {
280
+ return Tensor5DCoord(Base::operator-(b));
281
+ }
282
+
283
+ /// Element-wise multiplication
284
+ CUTLASS_HOST_DEVICE
285
+ Tensor5DCoord operator*(Base const& b) const {
286
+ return Tensor5DCoord(Base::operator*(b));
287
+ }
288
+
289
+ /// Element-wise division
290
+ CUTLASS_HOST_DEVICE
291
+ Tensor5DCoord operator/(Base const& b) const {
292
+ return Tensor5DCoord(Base::operator/(b));
293
+ }
294
+
295
+ /// In-place addition
296
+ CUTLASS_HOST_DEVICE
297
+ Tensor5DCoord& operator+=(Base const& b) {
298
+ Base::operator+=(b);
299
+ return *this;
300
+ }
301
+
302
+ /// In-place subtraction
303
+ CUTLASS_HOST_DEVICE
304
+ Tensor5DCoord& operator-=(Base const& b) {
305
+ Base::operator-=(b);
306
+ return *this;
307
+ }
308
+
309
+ /// In-place multiplication
310
+ CUTLASS_HOST_DEVICE
311
+ Tensor5DCoord& operator*=(Base const& b) {
312
+ Base::operator*=(b);
313
+ return *this;
314
+ }
315
+
316
+ /// In-place division
317
+ CUTLASS_HOST_DEVICE
318
+ Tensor5DCoord& operator/=(Base const& b) {
319
+ Base::operator/=(b);
320
+ return *this;
321
+ }
322
+ };
323
+
324
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
325
+
326
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data.
33
+ */
34
+ #pragma once
35
+
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/coord.h"
39
+ #include "cutlass/platform/platform.h"
40
+ #include "cutlass/subbyte_reference.h"
41
+
42
+ namespace cutlass {
43
+
44
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Default layout function from coordinates in a tensor's index space into the n-D array held
47
+ /// in memory.
48
+ ///
49
+ /// All layout functions must define at least the members shown in IdentityTensorLayout<>.
50
+ template <int Rank>
51
+ class IdentityTensorLayout {
52
+ public:
53
+ /// Logical rank of tensor
54
+ static int const kRank = Rank;
55
+
56
+ /// Rank of stride vector
57
+ static int const kStrideRank = Rank;
58
+
59
+ /// Index type used for coordinates
60
+ using Index = int32_t;
61
+
62
+ /// Long index type used for offsets
63
+ using LongIndex = int64_t;
64
+
65
+ /// Logical coordinate
66
+ using TensorCoord = Coord<kRank, Index>;
67
+
68
+ /// Stride vector
69
+ using Stride = Coord<kStrideRank, Index>;
70
+
71
+ private:
72
+
73
+ //
74
+ // Data members
75
+ //
76
+
77
+ /// Stride data member
78
+ Stride stride_;
79
+
80
+ public:
81
+
82
+ //
83
+ // Methods
84
+ //
85
+
86
+ CUTLASS_HOST_DEVICE
87
+ IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { }
88
+
89
+ /// Returns the offset of a coordinate in linear memory
90
+ CUTLASS_HOST_DEVICE
91
+ LongIndex operator()(Coord<Rank> const &coord) const {
92
+ return coord.dot(stride_);
93
+ }
94
+
95
+ /// Returns the stride of the layout
96
+ CUTLASS_HOST_DEVICE
97
+ Stride stride() const {
98
+ return stride_;
99
+ }
100
+
101
+ /// Returns the stride of the layout
102
+ CUTLASS_HOST_DEVICE
103
+ Stride & stride() {
104
+ return stride_;
105
+ }
106
+
107
+ /// Compute the number of contiguous elements needed to store a tensor with the given size
108
+ CUTLASS_HOST_DEVICE
109
+ LongIndex capacity(TensorCoord const &size) const {
110
+ int idx = stride_.max_dim_index();
111
+ return stride_[idx] * size[idx];
112
+ }
113
+ };
114
+
115
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
116
+
117
+ /* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
118
+ and layout within memory. A TensorRef combines a pointer and a Layout concept
119
+
120
+ Examples:
121
+
122
+ (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h)
123
+
124
+ 1. Column-major matrix may be represented as a rank=2 tensor:
125
+
126
+ TensorRef<float, layout::ColumnMajor> A(ptr_A, ldm);
127
+
128
+ 2. Row-major matrix may be represented as a rank=2 tensor:
129
+
130
+ TensorRef<float, layout::RowMajor> B(ptr_A, ldm);
131
+
132
+ 3. An interleaved matrix may be represented as a rank=2 tensor:
133
+
134
+ TensorRef<int8_t, layout::ColumnMajorInterleaved<32> > C;
135
+
136
+ 4. A helper exists to define a TensorRef for a contiguous matrix whose layout
137
+ is not known at compile time.
138
+
139
+ int ldm; // leading dimension
140
+ layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor
141
+
142
+
143
+ TensorRef<int, layout::ContiguousMatrix> E(ptr_E, {ldm, kind});
144
+
145
+ */
146
+ template <
147
+ /// Data type of element stored within tensor (concept: NumericType)
148
+ typename Element_,
149
+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
150
+ typename Layout_
151
+ >
152
+ class TensorRef {
153
+ public:
154
+ /// Data type of individual access
155
+ using Element = Element_;
156
+
157
+ /// Mapping function from logical coordinate to linear memory
158
+ using Layout = Layout_;
159
+
160
+ /// Reference type to an element
161
+ using Reference = typename platform::conditional<
162
+ sizeof_bits<Element>::value >= 8,
163
+ Element &,
164
+ SubbyteReference<Element>
165
+ >::type;
166
+
167
+ /// Logical rank of tensor index space
168
+ static int const kRank = Layout::kRank;
169
+
170
+ /// Index type
171
+ using Index = typename Layout::Index;
172
+
173
+ /// Long index used for pointer offsets
174
+ using LongIndex = typename Layout::LongIndex;
175
+
176
+ /// Coordinate in logical tensor space
177
+ using TensorCoord = typename Layout::TensorCoord;
178
+
179
+ /// Layout's stride vector
180
+ using Stride = typename Layout::Stride;
181
+
182
+ /// TensorRef to constant data
183
+ using ConstTensorRef = TensorRef<
184
+ typename platform::remove_const<Element>::type const,
185
+ Layout>;
186
+
187
+ /// TensorRef to non-constant data
188
+ using NonConstTensorRef = TensorRef<
189
+ typename platform::remove_const<Element>::type,
190
+ Layout>;
191
+
192
+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
193
+ /// scalar, but degenerate cases such as these are difficult to accommodate without
194
+ /// extensive C++ metaprogramming or support for zero-length arrays.
195
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
196
+
197
+ private:
198
+
199
+ /// Pointer
200
+ Element* ptr_;
201
+
202
+ /// Layout object maps logical coordinates to linear offsets
203
+ Layout layout_;
204
+
205
+ public:
206
+
207
+ //
208
+ // Methods
209
+ //
210
+
211
+ /// Constructs a TensorRef with a pointer and layout object.
212
+ CUTLASS_HOST_DEVICE
213
+ TensorRef(): ptr_(nullptr) {
214
+
215
+ }
216
+
217
+ /// Constructs a TensorRef with a pointer and layout object.
218
+ CUTLASS_HOST_DEVICE
219
+ TensorRef(
220
+ Element *ptr, ///< pointer to start of tensor
221
+ Layout const &layout ///< layout object containing stride and mapping function
222
+ ):
223
+ ptr_(ptr), layout_(layout) {
224
+
225
+ }
226
+
227
+ /// Converting constructor from TensorRef to non-constant data.
228
+ template<typename _Magic = int>
229
+ CUTLASS_HOST_DEVICE
230
+ TensorRef(
231
+ NonConstTensorRef const &ref, ///< TensorRef to non-const data
232
+ ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
233
+ _Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
234
+ ):
235
+ ptr_(ref.data()), layout_(ref.layout()) { }
236
+
237
+ /// Returns a reference to constant-valued tensor.
238
+ CUTLASS_HOST_DEVICE
239
+ ConstTensorRef const_ref() const {
240
+ return ConstTensorRef(ptr_, layout_);
241
+ }
242
+
243
+ CUTLASS_HOST_DEVICE
244
+ NonConstTensorRef non_const_ref() const {
245
+ return NonConstTensorRef(const_cast<typename platform::remove_const<Element>::type *>(ptr_), layout_);
246
+ }
247
+
248
+ /// Updates only the pointer
249
+ CUTLASS_HOST_DEVICE
250
+ void reset(Element* ptr = nullptr) {
251
+ ptr_ = ptr;
252
+ }
253
+
254
+ /// Updates the pointer and layout object
255
+ CUTLASS_HOST_DEVICE
256
+ void reset(Element* ptr, Layout const &layout) {
257
+ ptr_ = ptr;
258
+ layout_ = layout;
259
+ }
260
+
261
+ /// Returns true if the TensorRef is non-null
262
+ CUTLASS_HOST_DEVICE
263
+ bool good() const {
264
+ return ptr_ != nullptr;
265
+ }
266
+
267
+ /// Returns the pointer to referenced data
268
+ CUTLASS_HOST_DEVICE
269
+ Element * data() const { return ptr_; }
270
+
271
+ /// Returns a reference to the element at a given linear index
272
+ CUTLASS_HOST_DEVICE
273
+ Reference data(LongIndex idx) const {
274
+ return ReferenceFactory<typename platform::remove_const<Element>::type,
275
+ (sizeof_bits<Element>::value < 8)>::get(ptr_, idx);
276
+ }
277
+
278
+ /// Returns the layout object
279
+ CUTLASS_HOST_DEVICE
280
+ Layout & layout() {
281
+ return layout_;
282
+ }
283
+
284
+ /// Returns the layout object
285
+ CUTLASS_HOST_DEVICE
286
+ Layout layout() const {
287
+ return layout_;
288
+ }
289
+
290
+ /// Returns the layout object's stride vector
291
+ CUTLASS_HOST_DEVICE
292
+ Stride stride() const {
293
+ return layout_.stride();
294
+ }
295
+
296
+ /// Returns the layout object's stride vector
297
+ CUTLASS_HOST_DEVICE
298
+ Stride & stride() {
299
+ return layout_.stride();
300
+ }
301
+
302
+ /// Returns the layout object's stride in a given physical dimension
303
+ CUTLASS_HOST_DEVICE
304
+ typename Layout::Stride::Index stride(int dim) const {
305
+ return layout_.stride().at(dim);
306
+ }
307
+
308
+ /// Returns the layout object's stride in a given physical dimension
309
+ CUTLASS_HOST_DEVICE
310
+ typename Layout::Stride::Index & stride(int dim) {
311
+ return layout_.stride().at(dim);
312
+ }
313
+
314
+ /// Computes the offset of an index from the origin of the tensor
315
+ CUTLASS_HOST_DEVICE
316
+ LongIndex offset(TensorCoord const& coord) const {
317
+ return layout_(coord);
318
+ }
319
+
320
+ /// Returns a reference to the element at a given Coord
321
+ CUTLASS_HOST_DEVICE
322
+ Reference at(TensorCoord const& coord) const {
323
+ return data(offset(coord));
324
+ }
325
+
326
+ /// Returns a reference to the element at a given Coord
327
+ CUTLASS_HOST_DEVICE
328
+ Reference operator[](TensorCoord const& coord) const {
329
+ return data(offset(coord));
330
+ }
331
+
332
+ /// Adds an offset to each pointer
333
+ CUTLASS_HOST_DEVICE
334
+ TensorRef & add_pointer_offset(LongIndex offset_) {
335
+ ptr_ = ReferenceFactory<typename platform::remove_const<Element>::type,
336
+ (sizeof_bits<Element>::value < 8)>::add_pointer_offset(ptr_, offset_);
337
+ return *this;
338
+ }
339
+
340
+ /// Adds an offset to each pointer
341
+ CUTLASS_HOST_DEVICE
342
+ TensorRef & add_coord_offset(TensorCoord const &coord) {
343
+ add_pointer_offset(offset(coord));
344
+ return *this;
345
+ }
346
+
347
+ /// Returns a TensorRef offset by a given amount
348
+ CUTLASS_HOST_DEVICE
349
+ TensorRef operator+(TensorCoord const& b) const {
350
+ TensorRef result(*this);
351
+ result.add_coord_offset(b);
352
+ return result;
353
+ }
354
+
355
+ /// Returns a TensorRef offset by a given amount
356
+ CUTLASS_HOST_DEVICE
357
+ TensorRef & operator+=(TensorCoord const& b) {
358
+ add_coord_offset(b);
359
+ return *this;
360
+ }
361
+
362
+ /// Returns a TensorRef offset by a given amount
363
+ CUTLASS_HOST_DEVICE
364
+ TensorRef operator-(TensorCoord const& b) const {
365
+ TensorRef result(*this);
366
+ result.add_pointer_offset(-offset(b));
367
+ return result;
368
+ }
369
+
370
+ /// Returns a TensorRef offset by a given amount
371
+ CUTLASS_HOST_DEVICE
372
+ TensorRef & operator-=(TensorCoord const& b) {
373
+ add_pointer_offset(-offset(b));
374
+ return *this;
375
+ }
376
+ };
377
+
378
+ /// Constructs a TensorRef, deducing types from arguments.
379
+ template <
380
+ typename Element,
381
+ typename Layout
382
+ >
383
+ CUTLASS_HOST_DEVICE
384
+ TensorRef<Element, Layout> make_TensorRef(Element *ptr, Layout const &layout) {
385
+ return TensorRef<Element, Layout>(ptr, layout);
386
+ }
387
+
388
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
389
+ //
390
+ // Partial specializations to handle degenerate and sub-byte cases.
391
+ //
392
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
393
+
394
+ template <
395
+ typename Element,
396
+ typename Layout
397
+ >
398
+ CUTLASS_HOST_DEVICE
399
+ bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment) {
400
+
401
+ int const kStrideRank = Layout::kStrideRank;
402
+
403
+ if (reinterpret_cast<uintptr_t>(ref.data()) % alignment) {
404
+ return false;
405
+ }
406
+
407
+ CUTLASS_PRAGMA_UNROLL
408
+ for (int i = 0; i < kStrideRank; ++i) {
409
+ if (ref.stride(i) % alignment) {
410
+ return false;
411
+ }
412
+ }
413
+
414
+ return true;
415
+ }
416
+
417
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
418
+
419
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a structure containing strides, bounds, and a pointer to tensor data.
33
+ */
34
+ #pragma once
35
+
36
+ #include <cstdint>
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/complex.h"
39
+ #include "cutlass/tensor_ref.h"
40
+
41
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
42
+
43
+ namespace cutlass {
44
+
45
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ template <typename Element_>
48
+ struct PlanarComplexReference {
49
+
50
+ //
51
+ // Type definitions
52
+ //
53
+
54
+ using Element = Element_;
55
+ using ComplexElement = complex<Element>;
56
+
57
+ //
58
+ // Data members
59
+ //
60
+
61
+ Element *real;
62
+ Element *imag;
63
+
64
+ //
65
+ // Methods
66
+ //
67
+
68
+ CUTLASS_HOST_DEVICE
69
+ PlanarComplexReference(
70
+ Element *real_ = nullptr,
71
+ Element *imag_ = nullptr
72
+ ):
73
+ real(real_), imag(imag_) { }
74
+
75
+ /// Loads the complex element
76
+ CUTLASS_HOST_DEVICE
77
+ operator complex<Element>() const {
78
+ return complex<Element>{*real, *imag};
79
+ }
80
+
81
+ /// Stores a complex element to the location pointed to by the reference
82
+ CUTLASS_HOST_DEVICE
83
+ PlanarComplexReference &operator=(complex<Element> const &rhs) {
84
+ *real = rhs.real();
85
+ *imag = rhs.imag();
86
+ return *this;
87
+ }
88
+ };
89
+
90
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
91
+
92
+ /* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
93
+ and layout within memory. A TensorRef combines a pointer and a Layout concept
94
+
95
+ */
96
+ template <
97
+ /// Data type of element stored within tensor (concept: NumericType)
98
+ typename Element_,
99
+ /// Defines a mapping from logical coordinate to linear memory (concept: Layout)
100
+ typename Layout_
101
+ >
102
+ class TensorRefPlanarComplex {
103
+ public:
104
+ /// Data type of individual access
105
+ using Element = Element_;
106
+
107
+ /// Complex element type
108
+ using ComplexElement = complex<Element>;
109
+
110
+ /// Mapping function from logical coordinate to linear memory
111
+ using Layout = Layout_;
112
+
113
+ static_assert(sizeof_bits<Element>::value >= 8,
114
+ "Planar complex not suitable for subbyte elements at this time");
115
+
116
+ /// Reference type to an element
117
+ using Reference = PlanarComplexReference<Element>;
118
+
119
+ /// Logical rank of tensor index space
120
+ static int const kRank = Layout::kRank;
121
+
122
+ /// Index type
123
+ using Index = typename Layout::Index;
124
+
125
+ /// Long index used for pointer offsets
126
+ using LongIndex = typename Layout::LongIndex;
127
+
128
+ /// Coordinate in logical tensor space
129
+ using TensorCoord = typename Layout::TensorCoord;
130
+
131
+ /// Layout's stride vector
132
+ using Stride = typename Layout::Stride;
133
+
134
+ /// TensorRef to constant data
135
+ using ConstTensorRef = TensorRefPlanarComplex<
136
+ typename platform::remove_const<Element>::type const,
137
+ Layout>;
138
+
139
+ /// TensorRef to non-constant data
140
+ using NonConstTensorRef = TensorRefPlanarComplex<
141
+ typename platform::remove_const<Element>::type,
142
+ Layout>;
143
+
144
+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
145
+ /// scalar, but degenerate cases such as these are difficult to accommodate without
146
+ /// extensive C++ metaprogramming or support for zero-length arrays.
147
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
148
+
149
+ private:
150
+
151
+ /// Pointer
152
+ Element* ptr_;
153
+
154
+ /// Layout object maps logical coordinates to linear offsets
155
+ Layout layout_;
156
+
157
+ /// Offset to imaginary part
158
+ LongIndex imaginary_stride_;
159
+
160
+ public:
161
+
162
+ //
163
+ // Methods
164
+ //
165
+
166
+ /// Constructs a TensorRef with a pointer and layout object.
167
+ CUTLASS_HOST_DEVICE
168
+ TensorRefPlanarComplex(
169
+ Element *ptr = nullptr, ///< pointer to start of tensor
170
+ Layout const &layout = Layout(), ///< layout object containing stride and mapping function
171
+ LongIndex imaginary_stride = 0
172
+ ):
173
+ ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) {
174
+
175
+ }
176
+
177
+ /// Converting constructor from TensorRef to non-constant data.
178
+ CUTLASS_HOST_DEVICE
179
+ TensorRefPlanarComplex(
180
+ NonConstTensorRef const &ref ///< TensorRef to non-const data
181
+ ):
182
+ ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { }
183
+
184
+ /// Returns a reference to constant-valued tensor.
185
+ CUTLASS_HOST_DEVICE
186
+ ConstTensorRef const_ref() const {
187
+ return ConstTensorRef(ptr_, layout_, imaginary_stride_);
188
+ }
189
+
190
+ CUTLASS_HOST_DEVICE
191
+ NonConstTensorRef non_const_ref() const {
192
+ return NonConstTensorRef(
193
+ const_cast<typename platform::remove_const<Element>::type *>(ptr_),
194
+ layout_,
195
+ imaginary_stride_);
196
+ }
197
+
198
+ /// Updates only the pointer
199
+ CUTLASS_HOST_DEVICE
200
+ void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) {
201
+ ptr_ = ptr;
202
+ imaginary_stride_ = imaginary_stride;
203
+ }
204
+
205
+ /// Updates the pointer and layout object
206
+ CUTLASS_HOST_DEVICE
207
+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) {
208
+ ptr_ = ptr;
209
+ layout_ = layout;
210
+ imaginary_stride_ = imaginary_stride;
211
+ }
212
+
213
+ /// Returns true if the TensorRef is non-null
214
+ CUTLASS_HOST_DEVICE
215
+ bool good() const {
216
+ return ptr_ != nullptr;
217
+ }
218
+
219
+ /// Returns the pointer to referenced data
220
+ CUTLASS_HOST_DEVICE
221
+ Element * data() const { return ptr_; }
222
+
223
+ /// Returns the pointer to referenced data
224
+ CUTLASS_HOST_DEVICE
225
+ Element * imaginary_data() const { return ptr_ + imaginary_stride_; }
226
+
227
+ /// Returns a reference to the element at a given linear index
228
+ CUTLASS_HOST_DEVICE
229
+ Reference data(LongIndex idx) const {
230
+ return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_);
231
+ }
232
+
233
+ /// Returns the layout object
234
+ CUTLASS_HOST_DEVICE
235
+ Layout & layout() {
236
+ return layout_;
237
+ }
238
+
239
+ /// Returns the layout object
240
+ CUTLASS_HOST_DEVICE
241
+ Layout layout() const {
242
+ return layout_;
243
+ }
244
+
245
+ /// Gets the stride to an imaginary element
246
+ LongIndex imaginary_stride() const {
247
+ return imaginary_stride_;
248
+ }
249
+
250
+ /// Gets the stride to an imaginary element
251
+ LongIndex &imaginary_stride() {
252
+ return imaginary_stride_;
253
+ }
254
+
255
+ /// Returns the layout object's stride vector
256
+ CUTLASS_HOST_DEVICE
257
+ Stride stride() const {
258
+ return layout_.stride();
259
+ }
260
+
261
+ /// Returns the layout object's stride vector
262
+ CUTLASS_HOST_DEVICE
263
+ Stride & stride() {
264
+ return layout_.stride();
265
+ }
266
+
267
+ /// Returns the layout object's stride in a given physical dimension
268
+ CUTLASS_HOST_DEVICE
269
+ Index stride(int dim) const {
270
+ return layout_.stride().at(dim);
271
+ }
272
+
273
+ /// Returns the layout object's stride in a given physical dimension
274
+ CUTLASS_HOST_DEVICE
275
+ Index & stride(int dim) {
276
+ return layout_.stride().at(dim);
277
+ }
278
+
279
+ /// Computes the offset of an index from the origin of the tensor
280
+ CUTLASS_HOST_DEVICE
281
+ LongIndex offset(TensorCoord const& coord) const {
282
+ return layout_(coord);
283
+ }
284
+
285
+ /// Returns a reference to the element at a given Coord
286
+ CUTLASS_HOST_DEVICE
287
+ Reference at(TensorCoord const& coord) const {
288
+ return data(offset(coord));
289
+ }
290
+
291
+ /// Returns a reference to the element at a given Coord
292
+ CUTLASS_HOST_DEVICE
293
+ Reference operator[](TensorCoord const& coord) const {
294
+ return data(offset(coord));
295
+ }
296
+
297
+ /// Adds an offset to each pointer
298
+ CUTLASS_HOST_DEVICE
299
+ TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) {
300
+ ptr_ += offset_;
301
+ return *this;
302
+ }
303
+
304
+ /// Adds an offset to each pointer
305
+ CUTLASS_HOST_DEVICE
306
+ TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) {
307
+ add_pointer_offset(offset(coord));
308
+ return *this;
309
+ }
310
+
311
+ /// Returns a TensorRef offset by a given amount
312
+ CUTLASS_HOST_DEVICE
313
+ TensorRefPlanarComplex operator+(TensorCoord const& b) const {
314
+ TensorRefPlanarComplex result(*this);
315
+ result.add_coord_offset(b);
316
+ return result;
317
+ }
318
+
319
+ /// Returns a TensorRef offset by a given amount
320
+ CUTLASS_HOST_DEVICE
321
+ TensorRefPlanarComplex & operator+=(TensorCoord const& b) {
322
+ add_coord_offset(b);
323
+ return *this;
324
+ }
325
+
326
+ /// Returns a TensorRef offset by a given amount
327
+ CUTLASS_HOST_DEVICE
328
+ TensorRefPlanarComplex operator-(TensorCoord const& b) const {
329
+ TensorRefPlanarComplex result(*this);
330
+ result.add_pointer_offset(-offset(b));
331
+ return result;
332
+ }
333
+
334
+ /// Returns a TensorRef offset by a given amount
335
+ CUTLASS_HOST_DEVICE
336
+ TensorRefPlanarComplex & operator-=(TensorCoord const& b) {
337
+ add_pointer_offset(-offset(b));
338
+ return *this;
339
+ }
340
+
341
+ /// TensorRef to real-valued tensor
342
+ CUTLASS_HOST_DEVICE
343
+ cutlass::TensorRef<Element, Layout> ref_real() const {
344
+ return cutlass::TensorRef<Element, Layout>(data(), layout());
345
+ }
346
+
347
+ /// TensorRef to real-valued tensor
348
+ CUTLASS_HOST_DEVICE
349
+ cutlass::TensorRef<Element, Layout> ref_imag() const {
350
+ return cutlass::TensorRef<Element, Layout>(imaginary_data(), layout());
351
+ }
352
+ };
353
+
354
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
355
+
356
+ /// Constructs a TensorRef, deducing types from arguments.
357
+ template <
358
+ typename Element,
359
+ typename Layout
360
+ >
361
+ CUTLASS_HOST_DEVICE
362
+ TensorRefPlanarComplex<Element, Layout> make_TensorRefPlanarComplex(
363
+ Element *ptr,
364
+ Layout const &layout,
365
+ int64_t imaginary_stride) {
366
+
367
+ return TensorRefPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride);
368
+ }
369
+
370
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
371
+
372
+ } // namespace cutlass
373
+
374
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a structure containing strides and a pointer to tensor data.
33
+
34
+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
35
+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
36
+ data storage and is therefore lightweight and may be embedded in larger tensor objects or
37
+ memory structures.
38
+
39
+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
40
+ linear memory.
41
+ */
42
+
43
+ #pragma once
44
+
45
+ #if !defined(__CUDACC_RTC__)
46
+ #include <cmath>
47
+ #endif
48
+
49
+ #include "cutlass/cutlass.h"
50
+ #include "cutlass/tensor_ref.h"
51
+
52
+ namespace cutlass {
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
55
+
56
+ template <
57
+ /// Data type of element stored within tensor
58
+ typename Element_,
59
+ /// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
60
+ typename Layout_
61
+ >
62
+ class TensorView : public TensorRef<Element_, Layout_> {
63
+ public:
64
+
65
+ /// Base tensor reference
66
+ using Base = cutlass::TensorRef<Element_, Layout_>;
67
+
68
+ /// Mapping function from logical coordinate to internal n-D array
69
+ using Layout = Layout_;
70
+
71
+ /// TensorRef pointing to constant memory
72
+ using ConstTensorRef = typename Base::ConstTensorRef;
73
+
74
+ /// Underlying TensorRef type
75
+ using TensorRef = Base;
76
+
77
+ /// Data type of individual access
78
+ using Element = Element_;
79
+
80
+ /// Reference type to an element
81
+ using Reference = Element &;
82
+
83
+ /// Logical rank of tensor index space
84
+ static int const kRank = Layout::kRank;
85
+
86
+ /// Index type
87
+ using Index = typename Layout::Index;
88
+
89
+ /// Long index used for pointer offsets
90
+ using LongIndex = typename Layout::LongIndex;
91
+
92
+ /// Coordinate in logical tensor space
93
+ using TensorCoord = typename Layout::TensorCoord;
94
+
95
+ /// Coordinate in storage n-D array
96
+ using Stride = typename Layout::Stride;
97
+
98
+ /// TensorView pointing to constant memory
99
+ using ConstTensorView = TensorView<
100
+ typename platform::remove_const<Element>::type const,
101
+ Layout>;
102
+
103
+ /// TensorView pointing to non-constant memory
104
+ using NonConstTensorView = TensorView<
105
+ typename platform::remove_const<Element>::type,
106
+ Layout>;
107
+
108
+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
109
+ /// scalar, but degenerate cases such as these are difficult to accommodate without
110
+ /// extensive C++ metaprogramming or support for zero-length arrays.
111
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
112
+
113
+ private:
114
+
115
+ /// View extent
116
+ TensorCoord extent_;
117
+
118
+ public:
119
+
120
+ //
121
+ // Methods
122
+ //
123
+
124
+ /// Constructs a TensorView object
125
+ CUTLASS_HOST_DEVICE
126
+ TensorView() { }
127
+
128
+ /// Constructs a TensorView object
129
+ CUTLASS_HOST_DEVICE
130
+ TensorView(
131
+ Element *ptr, ///< pointer to start of tensor
132
+ Layout const &layout, ///< layout object containing stride and mapping function
133
+ TensorCoord const &extent ///< size of the view in logical coordinates
134
+ ):
135
+ Base(ptr, layout), extent_(extent) {
136
+
137
+ }
138
+
139
+ /// Constructs a TensorView object
140
+ CUTLASS_HOST_DEVICE
141
+ TensorView(
142
+ TensorRef const &ref, ///< pointer and layout object referencing a tensor
143
+ TensorCoord const &extent ///< logical size of tensor
144
+ ):
145
+ Base(ref), extent_(extent) {
146
+
147
+ }
148
+
149
+ /// Converting constructor from TensorRef to non-constant data.
150
+ CUTLASS_HOST_DEVICE
151
+ TensorView(
152
+ NonConstTensorView const &view ///< TensorView to non-const data
153
+ ):
154
+ Base(view), extent_(view.extent_) { }
155
+
156
+ /// Updates the pointer and layout object
157
+ CUTLASS_HOST_DEVICE
158
+ void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) {
159
+ Base::reset(ptr, layout);
160
+ this->resize(extent);
161
+ }
162
+
163
+ /// Updates the pointer
164
+ CUTLASS_HOST_DEVICE
165
+ void reset(Element* ptr) {
166
+ Base::reset(ptr);
167
+ }
168
+
169
+ /// Changes the size of the view without affecting pointer or layout
170
+ CUTLASS_HOST_DEVICE
171
+ void resize(TensorCoord const &extent) {
172
+ this->extent_ = extent;
173
+ }
174
+
175
+ /// Returns the extent of the view (the size along each logical dimension).
176
+ CUTLASS_HOST_DEVICE
177
+ TensorCoord const& extent() const { return extent_; }
178
+
179
+ /// Returns the extent along a particular logical dimension.
180
+ CUTLASS_HOST_DEVICE
181
+ Index extent(int dim) const { return extent_.at(dim); }
182
+
183
+ /// Returns the number of logical elements
184
+ CUTLASS_HOST_DEVICE
185
+ LongIndex size() const {
186
+ return extent_.product();
187
+ }
188
+
189
+ /// Determines whether a location is within a tensor
190
+ CUTLASS_HOST_DEVICE
191
+ bool contains(TensorCoord const& coord) const {
192
+ CUTLASS_PRAGMA_UNROLL
193
+ for (int dim = 0; dim < kRank; ++dim) {
194
+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
195
+ return false;
196
+ }
197
+ }
198
+ return true;
199
+ }
200
+
201
+ /// Returns a TensorRef pointing to the first element of the tensor.
202
+ CUTLASS_HOST_DEVICE
203
+ TensorRef ref() const {
204
+ return TensorRef(this->data(), this->layout());
205
+ }
206
+
207
+ /// Returns a TensorRef pointing to the first element of the tensor.
208
+ CUTLASS_HOST_DEVICE
209
+ ConstTensorRef const_ref() const {
210
+ return ConstTensorRef(this->data(), this->layout());
211
+ }
212
+
213
+ /// Returns a TensorView to const data
214
+ CUTLASS_HOST_DEVICE
215
+ ConstTensorView const_view() const {
216
+ return ConstTensorView(const_ref(), extent_);
217
+ }
218
+
219
+ /// Returns a Tensor_view given location and size quantities
220
+ CUTLASS_HOST_DEVICE
221
+ TensorView subview(
222
+ TensorCoord extent, ///< extent of the resulting view
223
+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view
224
+ ) const {
225
+
226
+ TensorView result(this->ref(), extent.clamp(extent_ - location));
227
+ result.add_coord_offset(location);
228
+ return result;
229
+ }
230
+
231
+ /// Returns the number of scalar elements needed to store tensor.
232
+ CUTLASS_HOST_DEVICE
233
+ size_t capacity() const {
234
+ return Base::layout().capacity(extent_);
235
+ }
236
+
237
+ /// Returns a TensorView offset by a given amount
238
+ CUTLASS_HOST_DEVICE
239
+ TensorView operator+(
240
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
241
+ ) const {
242
+
243
+ TensorView result(*this);
244
+ result.add_pointer_offset(this->offset(b));
245
+ return result;
246
+ }
247
+
248
+ /// Returns a TensorRef offset by a given amount
249
+ CUTLASS_HOST_DEVICE
250
+ TensorView& operator+=(
251
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
252
+ ) {
253
+
254
+ this->add_pointer_offset(this->offset(b));
255
+ return *this;
256
+ }
257
+
258
+ /// Returns a TensorRef offset by a given amount
259
+ CUTLASS_HOST_DEVICE
260
+ TensorView operator-(
261
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
262
+ ) const {
263
+
264
+ TensorRef result(*this);
265
+ result.add_pointer_offset(-this->offset(b));
266
+ return result;
267
+ }
268
+
269
+ /// Returns a TensorRef offset by a given amount
270
+ CUTLASS_HOST_DEVICE
271
+ TensorView& operator-=(
272
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
273
+ ) {
274
+
275
+ this->add_pointer_offset(-this->offset(b));
276
+ return *this;
277
+ }
278
+ };
279
+
280
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
281
+
282
+ /// Constructs a TensorRef, deducing types from arguments.
283
+ template <
284
+ typename Element,
285
+ typename Layout
286
+ >
287
+ CUTLASS_HOST_DEVICE TensorView<Element, Layout> make_TensorView(
288
+ Element *ptr,
289
+ Layout const &layout,
290
+ typename Layout::TensorCoord const &extent) {
291
+
292
+ return TensorView<Element, Layout>(ptr, layout, extent);
293
+ }
294
+
295
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
296
+
297
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a structure containing strides and a pointer to tensor data.
33
+
34
+ TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
35
+ it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
36
+ data storage and is therefore lightweight and may be embedded in larger tensor objects or
37
+ memory structures.
38
+
39
+ See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
40
+ linear memory.
41
+ */
42
+
43
+ #pragma once
44
+
45
+ #if !defined(__CUDACC_RTC__)
46
+ #include <cmath>
47
+ #endif
48
+
49
+ #include "cutlass/cutlass.h"
50
+ #include "cutlass/tensor_ref_planar_complex.h"
51
+ #include "cutlass/tensor_view.h" // cutlass::TensorView
52
+
53
+ namespace cutlass {
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
56
+
57
+ template <
58
+ /// Data type of element stored within tensor
59
+ typename Element_,
60
+ /// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
61
+ typename Layout_
62
+ >
63
+ class TensorViewPlanarComplex : public TensorRefPlanarComplex<Element_, Layout_> {
64
+ public:
65
+
66
+ /// Base tensor reference
67
+ using Base = cutlass::TensorRefPlanarComplex<Element_, Layout_>;
68
+
69
+ /// Mapping function from logical coordinate to internal n-D array
70
+ using Layout = Layout_;
71
+
72
+ /// TensorRef pointing to constant memory
73
+ using ConstTensorRef = typename Base::ConstTensorRef;
74
+
75
+ /// Underlying TensorRef type
76
+ using TensorRef = Base;
77
+
78
+ /// Data type of individual access
79
+ using Element = Element_;
80
+
81
+ /// Reference type to an element
82
+ using Reference = Element &;
83
+
84
+ /// Logical rank of tensor index space
85
+ static int const kRank = Layout::kRank;
86
+
87
+ /// Index type
88
+ using Index = typename Layout::Index;
89
+
90
+ /// Long index used for pointer offsets
91
+ using LongIndex = typename Layout::LongIndex;
92
+
93
+ /// Coordinate in logical tensor space
94
+ using TensorCoord = typename Layout::TensorCoord;
95
+
96
+ /// Coordinate in storage n-D array
97
+ using Stride = typename Layout::Stride;
98
+
99
+ /// TensorView pointing to constant memory
100
+ using ConstTensorView = TensorViewPlanarComplex<
101
+ typename platform::remove_const<Element>::type const,
102
+ Layout>;
103
+
104
+ /// TensorView pointing to non-constant memory
105
+ using NonConstTensorView = TensorViewPlanarComplex<
106
+ typename platform::remove_const<Element>::type,
107
+ Layout>;
108
+
109
+ /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
110
+ /// scalar, but degenerate cases such as these are difficult to accommodate without
111
+ /// extensive C++ metaprogramming or support for zero-length arrays.
112
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
113
+
114
+ private:
115
+
116
+ /// View extent
117
+ TensorCoord extent_;
118
+
119
+ public:
120
+
121
+ //
122
+ // Methods
123
+ //
124
+
125
+ /// Constructs a TensorView object
126
+ CUTLASS_HOST_DEVICE
127
+ TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) {
128
+
129
+ }
130
+
131
+ /// Constructs a TensorView object
132
+ CUTLASS_HOST_DEVICE
133
+ TensorViewPlanarComplex(
134
+ Element *ptr, ///< pointer to start of tensor
135
+ Layout const &layout, ///< layout object containing stride and mapping function
136
+ LongIndex imaginary_stride, ///< stride between real and imaginary part
137
+ TensorCoord const &extent ///< size of the view in logical coordinates
138
+ ):
139
+ Base(ptr, layout, imaginary_stride), extent_(extent) {
140
+
141
+ }
142
+
143
+ /// Constructs a TensorView object
144
+ CUTLASS_HOST_DEVICE
145
+ TensorViewPlanarComplex(
146
+ TensorRef const &ref, ///< pointer and layout object referencing a tensor
147
+ TensorCoord const &extent ///< logical size of tensor
148
+ ):
149
+ Base(ref), extent_(extent) {
150
+
151
+ }
152
+
153
+ /// Converting constructor from TensorRef to non-constant data.
154
+ CUTLASS_HOST_DEVICE
155
+ TensorViewPlanarComplex(
156
+ NonConstTensorView const &view ///< TensorView to non-const data
157
+ ):
158
+ Base(view), extent_(view.extent_) { }
159
+
160
+ /// Updates the pointer and layout object
161
+ CUTLASS_HOST_DEVICE
162
+ void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) {
163
+ Base::reset(ptr, layout, imaginary_stride);
164
+ this->resize(extent_);
165
+ }
166
+
167
+ /// Changes the size of the view without affecting pointer or layout
168
+ CUTLASS_HOST_DEVICE
169
+ void resize(TensorCoord extent) {
170
+ this->extent_ = extent;
171
+ }
172
+
173
+ /// Returns the extent of the view (the size along each logical dimension).
174
+ CUTLASS_HOST_DEVICE
175
+ TensorCoord const& extent() const { return extent_; }
176
+
177
+ /// Returns the extent along a particular logical dimension.
178
+ CUTLASS_HOST_DEVICE
179
+ Index extent(int dim) const { return extent_.at(dim); }
180
+
181
+ /// Determines whether a location is within a tensor
182
+ CUTLASS_HOST_DEVICE
183
+ bool contains(TensorCoord const& coord) const {
184
+ CUTLASS_PRAGMA_UNROLL
185
+ for (int dim = 0; dim < kRank; ++dim) {
186
+ if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
187
+ return false;
188
+ }
189
+ }
190
+ return true;
191
+ }
192
+
193
+ /// Returns a TensorRef pointing to the first element of the tensor.
194
+ CUTLASS_HOST_DEVICE
195
+ Base ref() const {
196
+ return Base(this->data(), this->layout(), this->imaginary_stride());
197
+ }
198
+
199
+ /// Returns a TensorRef pointing to the first element of the tensor.
200
+ CUTLASS_HOST_DEVICE
201
+ ConstTensorRef const_ref() const {
202
+ return ConstTensorRef(this->data(), this->layout());
203
+ }
204
+
205
+ /// Returns a TensorView to const data
206
+ CUTLASS_HOST_DEVICE
207
+ ConstTensorView const_view() const {
208
+ return ConstTensorView(const_ref(), extent_);
209
+ }
210
+
211
+ /// Returns a Tensor_view given location and size quantities
212
+ CUTLASS_HOST_DEVICE
213
+ TensorViewPlanarComplex subview(
214
+ TensorCoord extent, ///< extent of the resulting view
215
+ TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view
216
+ ) const {
217
+
218
+ TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location));
219
+ result.add_coord_offset(location);
220
+ return result;
221
+ }
222
+
223
+ /// Returns the number of scalar elements needed to store tensor.
224
+ CUTLASS_HOST_DEVICE
225
+ size_t capacity() const {
226
+ return Base::layout().capacity(extent_);
227
+ }
228
+
229
+ /// Returns a TensorView offset by a given amount
230
+ CUTLASS_HOST_DEVICE
231
+ TensorViewPlanarComplex operator+(
232
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
233
+ ) const {
234
+
235
+ TensorViewPlanarComplex result(*this);
236
+ result.add_pointer_offset(this->offset(b));
237
+ return result;
238
+ }
239
+
240
+ /// Returns a TensorRef offset by a given amount
241
+ CUTLASS_HOST_DEVICE
242
+ TensorViewPlanarComplex& operator+=(
243
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
244
+ ) {
245
+
246
+ this->add_pointer_offset(this->offset(b));
247
+ return *this;
248
+ }
249
+
250
+ /// Returns a TensorRef offset by a given amount
251
+ CUTLASS_HOST_DEVICE
252
+ TensorViewPlanarComplex operator-(
253
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
254
+ ) const {
255
+
256
+ TensorRef result(*this);
257
+ result.add_pointer_offset(-this->offset(b));
258
+ return result;
259
+ }
260
+
261
+ /// Returns a TensorRef offset by a given amount
262
+ CUTLASS_HOST_DEVICE
263
+ TensorViewPlanarComplex& operator-=(
264
+ TensorCoord const& b ///< offset in the logical coordinate space of the tensor
265
+ ) {
266
+
267
+ this->add_pointer_offset(-this->offset(b));
268
+ return *this;
269
+ }
270
+
271
+ /// TensorRef to real-valued tensor
272
+ CUTLASS_HOST_DEVICE
273
+ cutlass::TensorView<Element, Layout> view_real() const {
274
+ return cutlass::TensorView<Element, Layout>(this->data(), this->layout(), extent_);
275
+ }
276
+
277
+ /// TensorRef to real-valued tensor
278
+ CUTLASS_HOST_DEVICE
279
+ cutlass::TensorView<Element, Layout> view_imag() const {
280
+ return cutlass::TensorView<Element, Layout>(this->imaginary_data(), this->layout(), extent_);
281
+ }
282
+ };
283
+
284
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
285
+
286
+ /// Constructs a TensorRef, deducing types from arguments.
287
+ template <
288
+ typename Element,
289
+ typename Layout
290
+ >
291
+ CUTLASS_HOST_DEVICE TensorViewPlanarComplex<Element, Layout> make_TensorViewPlanarComplex(
292
+ Element *ptr,
293
+ Layout const &layout,
294
+ typename Layout::LongIndex imaginary_stride,
295
+ typename Layout::TensorCoord const &extent) {
296
+
297
+ return TensorViewPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride, extent);
298
+ }
299
+
300
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
301
+
302
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*!
32
+ \file
33
+ \brief Defines a proxy class for storing Tensor Float 32 data type.
34
+ */
35
+ #pragma once
36
+
37
+ #if defined(__CUDACC_RTC__)
38
+ #include "cutlass/floating_point_nvrtc.h"
39
+ #else
40
+ #include <cmath>
41
+ #include <limits>
42
+ #include <cstdint>
43
+ #include <cstring> // std::memcpy
44
+ #endif
45
+
46
+ #include "cutlass/cutlass.h"
47
+
48
+ namespace cutlass {
49
+
50
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
51
+
52
+ /// Tensor Float 32 data type
53
+ struct alignas(4) tfloat32_t {
54
+
55
+ //
56
+ // Data members
57
+ //
58
+
59
+ /// Storage type
60
+ uint32_t storage;
61
+
62
+ //
63
+ // Methods
64
+ //
65
+ private:
66
+ CUTLASS_HOST_DEVICE
67
+ static uint32_t float_to_storage(float s) {
68
+ #if defined(__CUDA_ARCH__)
69
+ uint32_t result = reinterpret_cast<uint32_t const &>(s);
70
+ #else
71
+ uint32_t result;
72
+ std::memcpy(&result, &s, sizeof(float));
73
+ #endif
74
+ return result;
75
+ }
76
+
77
+ public:
78
+ /// Constructs from an unsigned int
79
+ CUTLASS_HOST_DEVICE
80
+ static tfloat32_t bitcast(uint32_t x) {
81
+ tfloat32_t h;
82
+ h.storage = x;
83
+ return h;
84
+ }
85
+
86
+ /// Emulated rounding is fast in device code
87
+ CUTLASS_HOST_DEVICE
88
+ static tfloat32_t round_half_ulp_truncate(float const &s) {
89
+ uint32_t x = float_to_storage(s);
90
+
91
+ #if defined(__CUDA_ARCH__)
92
+ if (::isfinite(s)) {
93
+ x += 0x1000u;
94
+ }
95
+ #else
96
+ if (std::isfinite(s)) {
97
+ x += 0x1000u;
98
+ }
99
+ #endif
100
+
101
+ return tfloat32_t::bitcast(x);
102
+ }
103
+
104
+ tfloat32_t() = default;
105
+
106
+ /// Floating-point conversion - round toward nearest even
107
+ CUTLASS_HOST_DEVICE
108
+ explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { }
109
+
110
+ // Conversion from double (this rounds twice)
111
+ CUTLASS_HOST_DEVICE
112
+ explicit tfloat32_t(double x): tfloat32_t(float(x)) { }
113
+
114
+ /// Integer conversion - round toward zero
115
+ CUTLASS_HOST_DEVICE
116
+ explicit tfloat32_t(int x) {
117
+ float flt = static_cast<float>(x);
118
+ #if defined(__CUDA_ARCH__)
119
+ storage = reinterpret_cast<uint32_t const &>(flt);
120
+ #else
121
+ std::memcpy(&storage, &flt, sizeof(storage));
122
+ #endif
123
+ }
124
+
125
+ // Conversion to float
126
+ CUTLASS_HOST_DEVICE
127
+ operator float() const {
128
+
129
+ // Conversions to IEEE single-precision requires clearing dont-care bits
130
+ // of the mantissa.
131
+ unsigned bits = (storage & ~0x1fffu);
132
+
133
+ #if defined(__CUDA_ARCH__)
134
+ return reinterpret_cast<float const &>(bits);
135
+ #else
136
+ float flt;
137
+ std::memcpy(&flt, &bits, sizeof(flt));
138
+ return flt;
139
+ #endif
140
+ }
141
+
142
+ /// Converts to double
143
+ CUTLASS_HOST_DEVICE
144
+ explicit operator double() const {
145
+ return double(float(*this));
146
+ }
147
+
148
+ /// Converts to int
149
+ CUTLASS_HOST_DEVICE
150
+ explicit operator int() const {
151
+ return int(float(*this));
152
+ }
153
+
154
+ /// Casts to bool
155
+ CUTLASS_HOST_DEVICE
156
+ explicit operator bool() const {
157
+ return (float(*this) != 0.0f);
158
+ }
159
+
160
+ /// Obtains raw bits
161
+ CUTLASS_HOST_DEVICE
162
+ uint32_t raw() const {
163
+ return storage;
164
+ }
165
+
166
+ /// Returns the sign bit
167
+ CUTLASS_HOST_DEVICE
168
+ bool signbit() const {
169
+ return ((raw() & 0x80000000) != 0);
170
+ }
171
+
172
+ /// Returns the biased exponent
173
+ CUTLASS_HOST_DEVICE
174
+ int exponent_biased() const {
175
+ return int((raw() >> 23) & 0x0ff);
176
+ }
177
+
178
+ /// Returns the unbiased exponent
179
+ CUTLASS_HOST_DEVICE
180
+ int exponent() const {
181
+ return exponent_biased() - 127;
182
+ }
183
+
184
+ /// Returns the mantissa
185
+ CUTLASS_HOST_DEVICE
186
+ int mantissa() const {
187
+ return int(raw() & 0x7fffff);
188
+ }
189
+ };
190
+
191
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
192
+
193
+ CUTLASS_HOST_DEVICE
194
+ bool signbit(cutlass::tfloat32_t const& h) {
195
+ return h.signbit();
196
+ }
197
+
198
+ CUTLASS_HOST_DEVICE
199
+ cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) {
200
+ return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff);
201
+ }
202
+
203
+ CUTLASS_HOST_DEVICE
204
+ bool isnan(cutlass::tfloat32_t const& h) {
205
+ return (h.exponent_biased() == 0x0ff) && h.mantissa();
206
+ }
207
+
208
+ CUTLASS_HOST_DEVICE
209
+ bool isfinite(cutlass::tfloat32_t const& h) {
210
+ return (h.exponent_biased() != 0x0ff);
211
+ }
212
+
213
+ CUTLASS_HOST_DEVICE
214
+ cutlass::tfloat32_t nan_tf32(const char*) {
215
+ // NVIDIA canonical NaN
216
+ return cutlass::tfloat32_t::bitcast(0x7fffffff);
217
+ }
218
+
219
+ CUTLASS_HOST_DEVICE
220
+ bool isinf(cutlass::tfloat32_t const& h) {
221
+ return (h.exponent_biased() == 0x0ff) && !h.mantissa();
222
+ }
223
+
224
+ CUTLASS_HOST_DEVICE
225
+ bool isnormal(cutlass::tfloat32_t const& h) {
226
+ return h.exponent_biased() && h.exponent_biased() != 0x0ff;
227
+ }
228
+
229
+ CUTLASS_HOST_DEVICE
230
+ int fpclassify(cutlass::tfloat32_t const& h) {
231
+ int exp = h.exponent_biased();
232
+ int mantissa = h.mantissa();
233
+ if (exp == 0x0ff) {
234
+ if (mantissa) {
235
+ return FP_NAN;
236
+ }
237
+ else {
238
+ return FP_INFINITE;
239
+ }
240
+ }
241
+ else if (!exp) {
242
+ if (mantissa) {
243
+ return FP_SUBNORMAL;
244
+ }
245
+ else {
246
+ return FP_ZERO;
247
+ }
248
+ }
249
+ return FP_NORMAL;
250
+ }
251
+
252
+ CUTLASS_HOST_DEVICE
253
+ cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) {
254
+ #if defined(__CUDACC_RTC__)
255
+ return cutlass::tfloat32_t(sqrtf(float(h)));
256
+ #else
257
+ return cutlass::tfloat32_t(std::sqrt(float(h)));
258
+ #endif
259
+ }
260
+
261
+ CUTLASS_HOST_DEVICE
262
+ tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) {
263
+
264
+ uint32_t a_mag = (a.raw() & 0x7fffffff);
265
+ uint32_t b_sign = (b.raw() & 0x80000000);
266
+ uint32_t result = (a_mag | b_sign);
267
+
268
+ return tfloat32_t::bitcast(result);
269
+ }
270
+
271
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
272
+
273
+ } // namespace cutlass
274
+
275
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
276
+ //
277
+ // Standard Library operations and definitions
278
+ //
279
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
280
+
281
+ namespace std {
282
+
283
+ #if !defined(__CUDACC_RTC__)
284
+ /// Numeric limits
285
+ template <>
286
+ struct numeric_limits<cutlass::tfloat32_t> {
287
+ static bool const is_specialized = true;
288
+ static bool const is_signed = true;
289
+ static bool const is_integer = false;
290
+ static bool const is_exact = false;
291
+ static bool const has_infinity = true;
292
+ static bool const has_quiet_NaN = true;
293
+ static bool const has_signaling_NaN = false;
294
+ static std::float_denorm_style const has_denorm = std::denorm_present;
295
+ static bool const has_denorm_loss = true;
296
+ static std::float_round_style const round_style = std::round_to_nearest;
297
+ static bool const is_iec559 = false;
298
+ static bool const is_bounded = true;
299
+ static bool const is_modulo = false;
300
+ static int const digits = 19;
301
+
302
+ /// Least positive value
303
+ static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); }
304
+
305
+ /// Minimum finite value
306
+ static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); }
307
+
308
+ /// Maximum finite value
309
+ static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); }
310
+
311
+ /// Returns smallest finite value
312
+ static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); }
313
+
314
+ /// Returns smallest finite value
315
+ static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); }
316
+
317
+ /// Returns smallest finite value
318
+ static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); }
319
+
320
+ /// Returns smallest finite value
321
+ static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }
322
+
323
+ /// Returns smallest finite value
324
+ static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }
325
+
326
+ /// Returns smallest finite value
327
+ static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); }
328
+ };
329
+ #endif
330
+
331
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
332
+
333
+ } // namespace std
334
+
335
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
336
+ //
337
+ // Arithmetic operators
338
+ //
339
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
340
+
341
+ namespace cutlass {
342
+
343
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
344
+
345
+ CUTLASS_HOST_DEVICE
346
+ bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) {
347
+ return float(lhs) == float(rhs);
348
+ }
349
+
350
+ CUTLASS_HOST_DEVICE
351
+ bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
352
+ return float(lhs) != float(rhs);
353
+ }
354
+
355
+ CUTLASS_HOST_DEVICE
356
+ bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) {
357
+ return float(lhs) < float(rhs);
358
+ }
359
+
360
+ CUTLASS_HOST_DEVICE
361
+ bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
362
+ return float(lhs) <= float(rhs);
363
+ }
364
+
365
+ CUTLASS_HOST_DEVICE
366
+ bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) {
367
+ return float(lhs) > float(rhs);
368
+ }
369
+
370
+ CUTLASS_HOST_DEVICE
371
+ bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
372
+ return float(lhs) >= float(rhs);
373
+ }
374
+
375
+ CUTLASS_HOST_DEVICE
376
+ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {
377
+ return tfloat32_t(float(lhs) + float(rhs));
378
+ }
379
+
380
+
381
+ CUTLASS_HOST_DEVICE
382
+ tfloat32_t operator-(tfloat32_t const& lhs) {
383
+ return tfloat32_t::bitcast(0x80000000 ^ lhs.raw());
384
+ }
385
+
386
+ CUTLASS_HOST_DEVICE
387
+ tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) {
388
+ return tfloat32_t(float(lhs) - float(rhs));
389
+ }
390
+
391
+ CUTLASS_HOST_DEVICE
392
+ tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) {
393
+ return tfloat32_t(float(lhs) * float(rhs));
394
+ }
395
+
396
+ CUTLASS_HOST_DEVICE
397
+ tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) {
398
+ return tfloat32_t(float(lhs) / float(rhs));
399
+ }
400
+
401
+ CUTLASS_HOST_DEVICE
402
+ tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) {
403
+ lhs = tfloat32_t(float(lhs) + float(rhs));
404
+ return lhs;
405
+ }
406
+
407
+ CUTLASS_HOST_DEVICE
408
+ tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) {
409
+ lhs = tfloat32_t(float(lhs) - float(rhs));
410
+ return lhs;
411
+ }
412
+
413
+ CUTLASS_HOST_DEVICE
414
+ tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) {
415
+ lhs = tfloat32_t(float(lhs) * float(rhs));
416
+ return lhs;
417
+ }
418
+
419
+ CUTLASS_HOST_DEVICE
420
+ tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) {
421
+ lhs = tfloat32_t(float(lhs) / float(rhs));
422
+ return lhs;
423
+ }
424
+
425
+ CUTLASS_HOST_DEVICE
426
+ tfloat32_t& operator++(tfloat32_t & lhs) {
427
+ float tmp(lhs);
428
+ ++tmp;
429
+ lhs = tfloat32_t(tmp);
430
+ return lhs;
431
+ }
432
+
433
+ CUTLASS_HOST_DEVICE
434
+ tfloat32_t& operator--(tfloat32_t & lhs) {
435
+ float tmp(lhs);
436
+ --tmp;
437
+ lhs = tfloat32_t(tmp);
438
+ return lhs;
439
+ }
440
+
441
+ CUTLASS_HOST_DEVICE
442
+ tfloat32_t operator++(tfloat32_t & lhs, int) {
443
+ tfloat32_t ret(lhs);
444
+ float tmp(lhs);
445
+ tmp++;
446
+ lhs = tfloat32_t(tmp);
447
+ return ret;
448
+ }
449
+
450
+ CUTLASS_HOST_DEVICE
451
+ tfloat32_t operator--(tfloat32_t & lhs, int) {
452
+ tfloat32_t ret(lhs);
453
+ float tmp(lhs);
454
+ tmp--;
455
+ lhs = tfloat32_t(tmp);
456
+ return ret;
457
+ }
458
+
459
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
460
+
461
+ } // namespace cutlass
462
+
463
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
464
+
465
+ //
466
+ // User-defined literals
467
+ //
468
+
469
+ CUTLASS_HOST_DEVICE
470
+ cutlass::tfloat32_t operator "" _tf32(long double x) {
471
+ return cutlass::tfloat32_t(float(x));
472
+ }
473
+
474
+ CUTLASS_HOST_DEVICE
475
+ cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) {
476
+ return cutlass::tfloat32_t(int(x));
477
+ }
478
+
479
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Defines a matrix object intended for storing data in registers and operations within
33
+ a CUDA thread.
34
+ */
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/matrix_coord.h"
40
+
41
+ namespace cutlass {
42
+ namespace thread {
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ /// Per-thread matrix object storing a packed matrix
47
+ template <
48
+ typename Element,
49
+ int Rows,
50
+ int Columns,
51
+ typename Layout = layout::RowMajor
52
+ >
53
+ class Matrix : public Array<Element, Rows * Columns> {
54
+ public:
55
+
56
+ // Verify layout refers to a rank=2 matrix.
57
+ static_assert(
58
+ Layout::kRank == 2,
59
+ "Layout type must refer to a rank=2 matrix");
60
+
61
+ /// Base type
62
+ using Base = Array<Element, Rows * Columns>;
63
+
64
+ /// Element type
65
+ using Element = Element_;
66
+
67
+ /// Number of rows
68
+ static int const kRows = Rows;
69
+
70
+ /// Number of columns
71
+ static int const kColumns = Columns;
72
+
73
+ /// Layout within the array
74
+ using Layout = Layout_;
75
+
76
+ /// Reference type to an element
77
+ using Reference = Element &;
78
+
79
+ /// Logical rank of tensor index space
80
+ static int const kRank = 2;
81
+
82
+ /// Index type
83
+ using Index = typename Layout::Index;
84
+
85
+ /// Long index used for pointer offsets
86
+ using LongIndex = typename Layout::LongIndex;
87
+
88
+ /// Coordinate in logical tensor space
89
+ using TensorCoord = typename Layout::TensorCoord;
90
+
91
+ /// Stride type
92
+ using Stride = typename Layout::Stride;
93
+
94
+ /// TensorRef to matrix object
95
+ using TensorRef = TensorRef<Element, kRank, Layout>;
96
+
97
+ /// TensorRef to constant matrix object
98
+ using ConstTensorRef = typename TensorRef::ConstTensorRef;
99
+
100
+ /// TensorRef to matrix object
101
+ using TensorView = TensorView<Element, kRank, Layout>;
102
+
103
+ /// TensorRef to constant matrix object
104
+ using ConstTensorView = typename TensorView::ConstTensorView;
105
+
106
+ /// Diagonal vector
107
+ using Diagonal = Vector<Element, __NV_STD_MIN(kRows, kColumns)>;
108
+
109
+ private:
110
+
111
+
112
+ public:
113
+
114
+ //
115
+ // Methods
116
+ //
117
+
118
+ /// Returns the size of the object
119
+ CUTLASS_HOST_DEVICE
120
+ static MatrixCoord extent() {
121
+ return make_Coord(kRows, kColumns);
122
+ }
123
+
124
+ /// Returns the layout object
125
+ CUTLASS_HOST_DEVICE
126
+ static Layout layout() {
127
+ return Layout::packed(extent());
128
+ }
129
+
130
+ /// Ctor
131
+ CUTLASS_HOST_DEVICE
132
+ Matrix() { }
133
+
134
+ /// Ctor
135
+ CUTLASS_HOST_DEVICE
136
+ Matrix(Diagonal const &diag) {
137
+ }
138
+
139
+ /// Returns a TensorRef pointing to the first element of the tensor.
140
+ CUTLASS_HOST_DEVICE
141
+ TensorRef ref() {
142
+ return TensorRef(this->data(), layout());
143
+ }
144
+
145
+ /// Returns a TensorRef pointing to the first element of the tensor.
146
+ CUTLASS_HOST_DEVICE
147
+ ConstTensorRef const_ref() const {
148
+ return ConstTensorRef(this->data(), layout());
149
+ }
150
+
151
+ /// Returns a TensorRef pointing to the first element of the tensor.
152
+ CUTLASS_HOST_DEVICE
153
+ TensorView view() {
154
+ return TensorView(ref(), extent());
155
+ }
156
+
157
+ /// Returns a TensorView to const data
158
+ CUTLASS_HOST_DEVICE
159
+ ConstTensorView const_view() const {
160
+ return ConstTensorView(const_ref(), extent());
161
+ }
162
+
163
+ /// Returns a reference to the element at a given Coord
164
+ CUTLASS_HOST_DEVICE
165
+ Reference at(MatrixCoord const& coord) const {
166
+ typename Base::size_type offset_(layout().offset(coord));
167
+ return Base::at(offset_);
168
+ }
169
+
170
+ /// Returns the number of scalar elements needed to store tensor.
171
+ CUTLASS_HOST_DEVICE
172
+ LongIndex capacity() const {
173
+ return LongIndex(Base::size());
174
+ }
175
+ };
176
+
177
+ /////////////////////////////////////////////////////////////////////////////////////////////////
178
+
179
+ /// Column vector defined as a matrix with exactly one column
180
+ template <
181
+ typename Element,
182
+ int Rows,
183
+ typename Layout = layout::ColumnMajor
184
+ >
185
+ using ColumnVector = Matrix<Element, Rows, 1, Layout>;
186
+
187
+ /// Row vector defined as a matrix with exactly one row
188
+ template <
189
+ typename Element,
190
+ int Columns,
191
+ typename Layout = layout::RowMajor
192
+ >
193
+ using RowVector = Matrix<Element, 1, Columns, Layout>;
194
+
195
+ /////////////////////////////////////////////////////////////////////////////////////////////////
196
+
197
+ } // namespace thread
198
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Helpers for optionally tracing through code when debugging.
33
+
34
+ This file is to be included after all other headers.
35
+ */
36
+
37
+ #pragma once
38
+
39
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ // Tracing options
42
+ #ifndef CUTLASS_DEBUG_TRACE_LEVEL
43
+ #define CUTLASS_DEBUG_TRACE_LEVEL 0
44
+ #endif
45
+
46
+ #if CUTLASS_DEBUG_TRACE_LEVEL
47
+ #include <iostream>
48
+ #include "cutlass/core_io.h"
49
+ #if defined(__CUDA_ARCH__)
50
+ #define CUTLASS_TRACE_HOST(x)
51
+ #else
52
+ #define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; }
53
+ #endif
54
+ #else
55
+ #define CUTLASS_TRACE_HOST(x)
56
+ #endif
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
59
+
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing how threads are mapped to a given tile.
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cute/arch/mma_sm90_gmma.hpp"
38
+ /////////////////////////////////////////////////////////////////////////////////////////////////
39
+
40
+ namespace cutlass {
41
+ namespace transform {
42
+ namespace collective {
43
+
44
+ /////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ namespace detail {
47
+ using namespace cute;
48
+
49
+ template <bool Transpose, class SmemLayoutAtom, class ElementType>
50
+ constexpr auto
51
+ gmma_smem_transpose_or_passthrough() {
52
+ if constexpr (Transpose) {
53
+ if constexpr (cute::is_same_v<GMMA::Layout_MN_SW128_Atom<ElementType>, SmemLayoutAtom>) {
54
+ return GMMA::Layout_K_SW128_Atom<ElementType>{};
55
+ }
56
+ else if constexpr (cute::is_same_v<GMMA::Layout_MN_SW64_Atom<ElementType>, SmemLayoutAtom>) {
57
+ return GMMA::Layout_K_SW64_Atom<ElementType>{};
58
+ }
59
+ else if constexpr (cute::is_same_v<GMMA::Layout_MN_SW32_Atom<ElementType>, SmemLayoutAtom>) {
60
+ return GMMA::Layout_K_SW32_Atom<ElementType>{};
61
+ }
62
+ else if constexpr (cute::is_same_v<GMMA::Layout_MN_INTER_Atom<ElementType>, SmemLayoutAtom>) {
63
+ return GMMA::Layout_K_INTER_Atom<ElementType>{};
64
+ }
65
+ else {
66
+ static_assert(cutlass::detail::dependent_false<SmemLayoutAtom>, "Unsupported Layout_SW_Atom for B SMEM transposition");
67
+ }
68
+ }
69
+ else {
70
+ return SmemLayoutAtom{};
71
+ }
72
+ }
73
+
74
+ template <class SmemCopyAtom, class ElementType>
75
+ constexpr auto
76
+ use_universal_transposition() {
77
+ if constexpr (sizeof(ElementType) == 1) {
78
+ return !cute::is_same_v<GMMA::Layout_MN_SW128_Atom<ElementType>, SmemCopyAtom>;
79
+ }
80
+ else if constexpr (sizeof(ElementType) == 4){
81
+ // Only universal transposition can handle SW64 and Non swizzle SMEM layout
82
+ if constexpr (cute::is_same_v<GMMA::Layout_MN_SW64_Atom<ElementType>, SmemCopyAtom> ||
83
+ cute::is_same_v<GMMA::Layout_MN_INTER_Atom<ElementType>, SmemCopyAtom>) {
84
+ return true;
85
+ }
86
+ else {
87
+ return false;
88
+ }
89
+ }
90
+ else {
91
+ static_assert(cutlass::detail::dependent_false<ElementType>, "Unsupported ElementType for B SMEM transposition");
92
+ }
93
+ }
94
+
95
+ template<
96
+ class TiledMma_,
97
+ class SmemLayoutB_,
98
+ class SmemLayoutAtomB_,
99
+ class ElementB_>
100
+ class NoTranspositionOperandB {
101
+ public:
102
+ using TiledMma = TiledMma_;
103
+ using SmemLayoutB = SmemLayoutB_;
104
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
105
+ using ElementB = ElementB_;
106
+
107
+ constexpr CUTLASS_HOST_DEVICE
108
+ NoTranspositionOperandB(
109
+ int,
110
+ int,
111
+ TiledMma,
112
+ SmemLayoutB,
113
+ SmemLayoutAtomB,
114
+ ElementB) { }
115
+
116
+ template <
117
+ class TensorSmemB,
118
+ class TensorTransposedSmemB>
119
+ CUTLASS_DEVICE void operator()(
120
+ TensorSmemB const&,
121
+ TensorTransposedSmemB const&,
122
+ int, int) { }
123
+
124
+ CUTLASS_DEVICE void synchronize(int) { }
125
+
126
+ CUTLASS_DEVICE void synchronize() { }
127
+
128
+ template <
129
+ class TensorSmemB,
130
+ class TensorTransposedSmemB>
131
+ CUTLASS_DEVICE void transpose(
132
+ TensorSmemB const&,
133
+ TensorTransposedSmemB const&,
134
+ int) { }
135
+ };
136
+
137
+ template<
138
+ class TiledMma_,
139
+ class SmemLayoutB_,
140
+ class SmemLayoutAtomB_,
141
+ class ElementB_>
142
+ class UniversalTranspositionOperandB {
143
+ public:
144
+ using TiledMma = TiledMma_;
145
+ using SmemLayoutB = SmemLayoutB_;
146
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
147
+ using ElementB = ElementB_;
148
+
149
+ constexpr CUTLASS_HOST_DEVICE
150
+ UniversalTranspositionOperandB(
151
+ int warp_idx_,
152
+ int warp_group_thread_idx_,
153
+ TiledMma,
154
+ SmemLayoutB,
155
+ SmemLayoutAtomB,
156
+ ElementB)
157
+ : warp_idx(warp_idx_)
158
+ , warp_group_thread_idx(warp_group_thread_idx_) { }
159
+
160
+ template <
161
+ class TensorSmemB,
162
+ class TensorTransposedSmemB>
163
+ CUTLASS_DEVICE void operator()(
164
+ TensorSmemB const& sB,
165
+ TensorTransposedSmemB const& gmma_sB,
166
+ int read_stage, int current_step) {
167
+ if (current_step > 0) {
168
+ return;
169
+ }
170
+
171
+ constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
172
+ static_assert(NumMathWarpGroup == 1 ||
173
+ (!detail::use_universal_transposition<SmemLayoutAtomB, ElementB>() && NumMathWarpGroup == 2),
174
+ "Wrong math warp group number for TransposeB");
175
+ constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
176
+
177
+ constexpr int BytesPerSmemSwizzleUnit = 16;
178
+ constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
179
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
180
+ /// Universal transposition, need warp_group sync between load and store.
181
+ /// The number of reg used depends on the input elementB.
182
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
183
+ /*
184
+ In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location.
185
+ In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements:
186
+ K
187
+ ------------
188
+ | W0 W1 W2 W3 ---
189
+ | W0 W1 W2 W3 |
190
+ | W0 W1 W2 W3 | --> Copy Step 0
191
+ | W0 W1 W2 W3 ---
192
+ ....
193
+ | W0 W1 W2 W3 ---
194
+ | W0 W1 W2 W3 |
195
+ | W0 W1 W2 W3 | --> Copy Step n
196
+ | W0 W1 W2 W3 ---
197
+ */
198
+ static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout.");
199
+ constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<NumThreadsPerWarpGroup / WarpThreadShapeN>{}));
200
+
201
+ // Get copy tile and partition to each thread
202
+ auto sB_tiled_copy = make_tiled_copy(
203
+ Copy_Atom<DefaultCopy, ElementB>{},
204
+ WarpgroupThreadLayout, // thr_layout
205
+ Layout<_1>{} // val_layout
206
+ );
207
+ static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy.");
208
+
209
+ auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx);
210
+ Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K)
211
+ Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K)
212
+
213
+ // Divide partitioned tile to limit register usage
214
+ constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
215
+ constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB));
216
+ static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM.");
217
+
218
+ Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape);
219
+ Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape);
220
+ auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{}));
221
+
222
+ CUTLASS_PRAGMA_NO_UNROLL
223
+ for (int step = 0; step < CopySteps; ++step) {
224
+ copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment);
225
+
226
+ // Make sure all elements are read before being overwritten
227
+ __syncthreads();
228
+
229
+ copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step));
230
+ }
231
+ }
232
+
233
+ CUTLASS_DEVICE void synchronize(int step) {
234
+ if (step == 0) {
235
+ // SMEM fence to make sure B is transposed before math
236
+ cutlass::arch::fence_view_async_shared();
237
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
238
+ }
239
+ }
240
+
241
+ CUTLASS_DEVICE void synchronize() {
242
+ // SMEM fence to make sure B is transposed before math
243
+ cutlass::arch::fence_view_async_shared();
244
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
245
+ }
246
+
247
+ template <
248
+ class TensorSmemB,
249
+ class TensorTransposedSmemB>
250
+ CUTLASS_DEVICE void transpose(
251
+ TensorSmemB const& sB,
252
+ TensorTransposedSmemB const& gmma_sB,
253
+ int read_stage) {
254
+
255
+ this->operator()(sB, gmma_sB, read_stage, 0);
256
+ synchronize();
257
+
258
+ }
259
+
260
+ private:
261
+ const int warp_idx;
262
+ const int warp_group_thread_idx;
263
+ };
264
+
265
+ template<
266
+ class TiledMma_,
267
+ class SmemLayoutB_,
268
+ class SmemLayoutAtomB_,
269
+ class ElementB_>
270
+ class AsyncTranspositionOperandB {
271
+ public:
272
+
273
+ using TiledMma = TiledMma_;
274
+ using SmemLayoutB = SmemLayoutB_;
275
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
276
+ using ElementB = ElementB_;
277
+
278
+ static constexpr int Steps = 2;
279
+ static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
280
+ static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
281
+ static_assert(NumMathWarpGroup <= 2,
282
+ "Wrong math warp group number for TransposeB");
283
+ static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
284
+ static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
285
+
286
+ static constexpr int BytesPerSmemSwizzleUnit = 16;
287
+ static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
288
+ static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN;
289
+ static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
290
+
291
+ static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
292
+ static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
293
+ static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
294
+ static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
295
+ static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
296
+ static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT.
297
+ static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits,
298
+ static constexpr int NumBitsPerStep = 3;
299
+ static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits)
300
+ static constexpr int NumBitsPerWarp = 12;
301
+ // Number of warp_group_tiles
302
+ static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
303
+ "Copy size must evenly divide SMEM tile.");
304
+ static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
305
+
306
+ static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK,
307
+ "Need to be able to transpose first k-block in the first step");
308
+
309
+ constexpr CUTLASS_HOST_DEVICE
310
+ AsyncTranspositionOperandB(
311
+ int warp_idx_,
312
+ int warp_group_thread_idx_,
313
+ TiledMma,
314
+ SmemLayoutB,
315
+ SmemLayoutAtomB,
316
+ ElementB)
317
+ : warp_idx(warp_idx_)
318
+ , warp_group_thread_idx(warp_group_thread_idx_)
319
+ , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup)
320
+ , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_
321
+ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp)
322
+ , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_
323
+ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { }
324
+
325
+ template <
326
+ class TensorSmemB,
327
+ class TensorTransposedSmemB>
328
+ CUTLASS_DEVICE void operator()(
329
+ TensorSmemB const& sB,
330
+ TensorTransposedSmemB const& gmma_sB,
331
+ int read_stage, int current_step)
332
+ {
333
+ if (current_step >= StepsPerWarpGroup) {
334
+ return;
335
+ }
336
+
337
+ static constexpr auto WarpThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<WarpThreadShapeK>{}));
338
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
339
+ /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize.
340
+ /// In each step, one warp would hold two warp_tiles.
341
+ /// Step 0: Step 1:
342
+ /// W0 W1 W2 W3 -- -- -- --
343
+ /// W1 W0 -- -- -- -- W3 W2
344
+ /// W2 -- -- -- -- W3 W0 W1
345
+ /// W3 -- -- -- -- W2 W1 W0
346
+ ///
347
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
348
+ ///
349
+ /// Fully static coord LUT to avoid extra register use.
350
+ /// [warp_id][step][warp_tile][n / k]
351
+ /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7
352
+ /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0
353
+ /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1
354
+ /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2
355
+ /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3
356
+ ///
357
+ /// Encoding the coord of warp tile0 into two int64_t values.
358
+ /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern.
359
+ /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0.
360
+ /// The 2-step transposition and the 8-step transposition share the same encoding.
361
+ ///
362
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
363
+
364
+ // Divide entire SMEM to multiple warp_tiles
365
+ constexpr auto WarpTileShape = make_shape(Int<WarpTileSize>(), Int<WarpTileSize>());
366
+ Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape);
367
+ Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape);
368
+
369
+ // Get copy tile
370
+ auto sB_tiled_copy = make_tiled_copy(
371
+ Copy_Atom<DefaultCopy, ElementB>{},
372
+ WarpThreadLayout, // thr_layout
373
+ Layout<_1>{} // val_layout
374
+ );
375
+
376
+ static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy.");
377
+ auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx
378
+
379
+ // Construct fragments for transposition
380
+ Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{}))));
381
+ decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = {
382
+ make_fragment_like(tmp_tCsB),
383
+ make_fragment_like(tmp_tCsB)
384
+ };
385
+
386
+ [[maybe_unused]] int step = current_step * NumMathWarpGroup;
387
+ if constexpr (NumMathWarpGroup == 2) {
388
+ // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8.
389
+ step += warp_idx / (NumWarpsPerWarpGroup * 2);
390
+ }
391
+
392
+ int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step);
393
+ int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step);
394
+
395
+ if constexpr (NumMathWarpGroup == 2) {
396
+ tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
397
+ tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
398
+ }
399
+
400
+ // decoding the warp tile coord.
401
+ int warp_tile0_n, warp_tile0_k;
402
+ if constexpr (StepsPerWarpGroup <= NumStepsEncoded) {
403
+ warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep;
404
+ warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep;
405
+ } else {
406
+ warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group;
407
+ warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4;
408
+ }
409
+
410
+ int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k;
411
+ int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n;
412
+
413
+ CUTLASS_PRAGMA_UNROLL
414
+ for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) {
415
+
416
+ static_assert(TilesPerWarp == 2);
417
+
418
+ // [warp_tile][n/k]
419
+ const int warp_tile_coord[TilesPerWarp][2] = {
420
+ // n k
421
+ {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0
422
+ {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1
423
+ };
424
+
425
+ CUTLASS_PRAGMA_UNROLL
426
+ for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
427
+ Tensor tCsB = sB_thr_copy.partition_S(
428
+ flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
429
+ ); // (CPY, CPY_N, CPY_K)
430
+
431
+ copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]);
432
+ }
433
+
434
+ // Make sure elements in two 8x8 warp tiles are all consumed
435
+ __syncwarp();
436
+
437
+ CUTLASS_PRAGMA_UNROLL
438
+ for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
439
+ Tensor tCsB_transposed = sB_thr_copy.partition_D(
440
+ flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
441
+ ); // (CPY, CPY_N, CPY_K)
442
+ copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed);
443
+ }
444
+
445
+ } // loop warp_group_tile
446
+ }
447
+
448
+ CUTLASS_DEVICE void synchronize(int step) {
449
+ if (step < StepsPerWarpGroup) {
450
+ // SMEM fence to make sure B is transposed before math
451
+ cutlass::arch::fence_view_async_shared();
452
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
453
+ }
454
+ }
455
+
456
+ CUTLASS_DEVICE void synchronize() {
457
+ cutlass::arch::fence_view_async_shared();
458
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
459
+ }
460
+
461
+ template <
462
+ class TensorSmemB,
463
+ class TensorTransposedSmemB>
464
+ CUTLASS_DEVICE void transpose(
465
+ TensorSmemB const& sB,
466
+ TensorTransposedSmemB const& gmma_sB,
467
+ int read_stage) {
468
+
469
+ CUTLASS_PRAGMA_UNROLL
470
+ for(int i = 0; i < StepsPerWarpGroup; ++i) {
471
+ this->operator()(sB, gmma_sB, read_stage, i);
472
+ }
473
+ synchronize();
474
+
475
+ }
476
+ private:
477
+ const int warp_idx;
478
+ const int warp_group_thread_idx;
479
+ const int warp_idx_in_warp_group;
480
+ const int current_warp_tile_n_coord_LUT;
481
+ const int current_warp_tile_k_coord_LUT;
482
+ };
483
+
484
+ template<
485
+ class TiledMma_,
486
+ class SmemLayoutB_,
487
+ class SmemLayoutAtomB_,
488
+ class ElementB_>
489
+ class AsyncTranspositionOperandB_1BElementB {
490
+ public:
491
+
492
+ static_assert(sizeof(ElementB_) == 1);
493
+
494
+ using TiledMma = TiledMma_;
495
+ using SmemLayoutB = SmemLayoutB_;
496
+ using SmemLayoutAtomB = SmemLayoutAtomB_;
497
+ using ElementB = ElementB_;
498
+
499
+ static constexpr int Steps = 8;
500
+ static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
501
+ static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
502
+ static_assert(NumMathWarpGroup <= 2,
503
+ "Wrong math warp group number for TransposeB");
504
+ static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
505
+ static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
506
+
507
+ static constexpr int BytesPerSmemSwizzleUnit = 16;
508
+ static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
509
+ static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN;
510
+ static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
511
+
512
+ static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
513
+ static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
514
+ static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
515
+ static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
516
+ static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
517
+ static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT.
518
+ static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits,
519
+ static constexpr int NumBitsPerStep = 3;
520
+ static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits)
521
+ static constexpr int NumBitsPerWarp = 12;
522
+ // Number of warp_group_tiles
523
+ static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
524
+ "Copy size must evenly divide SMEM tile.");
525
+ static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
526
+
527
+ constexpr CUTLASS_HOST_DEVICE
528
+ AsyncTranspositionOperandB_1BElementB(
529
+ int warp_idx_,
530
+ int warp_group_thread_idx_,
531
+ TiledMma,
532
+ SmemLayoutB,
533
+ SmemLayoutAtomB,
534
+ ElementB)
535
+ : warp_idx(warp_idx_)
536
+ , warp_group_thread_idx(warp_group_thread_idx_)
537
+ , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup)
538
+ , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_
539
+ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp)
540
+ , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_
541
+ % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { }
542
+
543
+ template <
544
+ class TensorSmemB,
545
+ class TensorTransposedSmemB>
546
+ CUTLASS_DEVICE void operator()(
547
+ TensorSmemB const& sB,
548
+ TensorTransposedSmemB const& gmma_sB,
549
+ int read_stage, int current_step)
550
+ {
551
+ if (current_step > 0) {
552
+ return;
553
+ }
554
+
555
+ constexpr auto WarpThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<WarpThreadShapeK>{}));
556
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
557
+ /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize.
558
+ /// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage.
559
+ /// Step 0: Step 1: Step 2: Step 3:
560
+ /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
561
+ /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
562
+ /// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
563
+ /// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
564
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- --
565
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2
566
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1
567
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0
568
+ ///
569
+ /// Step 4: Step 5: Step 6: Step 7:
570
+ /// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
571
+ /// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
572
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- --
573
+ /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3
574
+ /// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- --
575
+ /// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- --
576
+ /// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- --
577
+ /// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- --
578
+ ///
579
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
580
+ ///
581
+ /// Fully static coord LUT to avoid extra register use.
582
+ /// [warp_id][step][warp_tile][n / k]
583
+ /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7
584
+ /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0
585
+ /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1
586
+ /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2
587
+ /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3
588
+ ///
589
+ /// Encoding the coord of warp tile0 into two int64_t values.
590
+ /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern.
591
+ /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0.
592
+ /// The 2-step transposition and the 8-step transposition share the same encoding.
593
+ ///
594
+ //////////////////////////////////////////////////////////////////////////////////////////////////////////////
595
+
596
+ // Divide entire SMEM to multiple warp_tiles
597
+ constexpr auto WarpTileShape = make_shape(Int<WarpTileSize>(), Int<WarpTileSize>());
598
+ Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape);
599
+ Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape);
600
+
601
+ // Get copy tile
602
+ auto sB_tiled_copy = make_tiled_copy(
603
+ Copy_Atom<DefaultCopy, ElementB>{},
604
+ WarpThreadLayout, // thr_layout
605
+ Layout<_1>{} // val_layout
606
+ );
607
+ static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy.");
608
+ auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx
609
+
610
+ // Construct fragments for transposition
611
+ Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{}))));
612
+ decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = {
613
+ make_fragment_like(tmp_tCsB),
614
+ make_fragment_like(tmp_tCsB)
615
+ };
616
+
617
+ CUTLASS_PRAGMA_NO_UNROLL
618
+ for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) {
619
+ int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT;
620
+ int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT;
621
+ constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
622
+
623
+ if constexpr (NumMathWarpGroup == 2) {
624
+ tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
625
+ tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
626
+ }
627
+
628
+ CUTLASS_PRAGMA_NO_UNROLL
629
+ for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) {
630
+ // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8.
631
+ int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2);
632
+ // decoding the warp tile coord.
633
+ int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group;
634
+ int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4;
635
+ int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k;
636
+ int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n;
637
+
638
+ tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep;
639
+ tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep;
640
+
641
+ static_assert(TilesPerWarp == 2);
642
+
643
+ // [warp_tile][n/k]
644
+ const int warp_tile_coord[TilesPerWarp][2] = {
645
+ // n k
646
+ {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0
647
+ {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1
648
+ };
649
+
650
+ CUTLASS_PRAGMA_UNROLL
651
+ for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
652
+ Tensor tCsB = sB_thr_copy.partition_S(
653
+ flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
654
+ ); // (CPY, CPY_N, CPY_K)
655
+
656
+ copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]);
657
+ }
658
+
659
+ // Make sure elements in two 8x8 warp tiles are all consumed
660
+ __syncwarp();
661
+
662
+ CUTLASS_PRAGMA_UNROLL
663
+ for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
664
+ Tensor tCsB_transposed = sB_thr_copy.partition_D(
665
+ flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
666
+ ); // (CPY, CPY_N, CPY_K)
667
+ copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed);
668
+ }
669
+ } // lock step
670
+ } // loop warp_group_tile
671
+ }
672
+
673
+ CUTLASS_DEVICE void synchronize(int step) {
674
+ if (step == 0) {
675
+ // SMEM fence to make sure B is transposed before math
676
+ cutlass::arch::fence_view_async_shared();
677
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
678
+ }
679
+ }
680
+
681
+ CUTLASS_DEVICE void synchronize() {
682
+ cutlass::arch::fence_view_async_shared();
683
+ cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
684
+ }
685
+
686
+ template <
687
+ class TensorSmemB,
688
+ class TensorTransposedSmemB>
689
+ CUTLASS_DEVICE void transpose(
690
+ TensorSmemB const& sB,
691
+ TensorTransposedSmemB const& gmma_sB,
692
+ int read_stage) {
693
+ this->operator()(sB, gmma_sB, read_stage, 0);
694
+ synchronize();
695
+ }
696
+
697
+ private:
698
+ const int warp_idx;
699
+ const int warp_group_thread_idx;
700
+ const int warp_idx_in_warp_group;
701
+ const int current_warp_tile_n_coord_LUT;
702
+ const int current_warp_tile_k_coord_LUT;
703
+ };
704
+
705
+
706
+ template<
707
+ class TiledMma,
708
+ class SmemLayoutB,
709
+ class SmemLayoutAtomB,
710
+ class ElementB,
711
+ bool TransposeB
712
+ >
713
+ constexpr CUTLASS_HOST_DEVICE
714
+ auto
715
+ make_transpose_operand_b(
716
+ int warp_idx,
717
+ int warp_group_thread_idx,
718
+ TiledMma,
719
+ SmemLayoutB,
720
+ SmemLayoutAtomB,
721
+ ElementB,
722
+ cute::bool_constant<TransposeB>)
723
+ {
724
+ if constexpr (!TransposeB) {
725
+ return NoTranspositionOperandB(
726
+ warp_idx, warp_group_thread_idx, TiledMma{},
727
+ SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
728
+ }
729
+ else if constexpr (use_universal_transposition<SmemLayoutAtomB, ElementB>()) {
730
+ return UniversalTranspositionOperandB(
731
+ warp_idx, warp_group_thread_idx, TiledMma{},
732
+ SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
733
+ }
734
+ else if constexpr (sizeof(ElementB) == 1) {
735
+ return AsyncTranspositionOperandB_1BElementB(
736
+ warp_idx, warp_group_thread_idx, TiledMma{},
737
+ SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
738
+ }
739
+ else {
740
+ return AsyncTranspositionOperandB(
741
+ warp_idx, warp_group_thread_idx, TiledMma{},
742
+ SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
743
+ }
744
+ }
745
+
746
+ }; // namespace detail
747
+
748
+ /////////////////////////////////////////////////////////////////////////////////////////////////
749
+
750
+ } // namespace collective
751
+ } // namespace transform
752
+ } // namespace cutlass
753
+
754
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Transform Kernel Universal adapter
33
+ */
34
+
35
+ #pragma once
36
+
37
+ // common
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/device_kernel.h"
40
+ #include "cutlass/gemm/gemm.h"
41
+ #include "cutlass/detail/layout.hpp"
42
+ #include "cutlass/detail/mma.hpp"
43
+ #include "cutlass/cuda_host_adapter.hpp"
44
+
45
+ #include "cutlass/kernel_launch.h"
46
+ #if !defined(__CUDACC_RTC__)
47
+ #include "cutlass/cluster_launch.hpp"
48
+ #include "cutlass/trace.h"
49
+ #endif // !defined(__CUDACC_RTC__)
50
+
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+
54
+ namespace cutlass::transform::device {
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////
57
+
58
+ template <class TransformKernel_>
59
+ class TransformUniversalAdapter
60
+ {
61
+ public:
62
+ using TransformKernel = GetUnderlyingKernel_t<TransformKernel_>;
63
+ using Arguments = typename TransformKernel::Arguments;
64
+ using Params = typename TransformKernel::Params;
65
+ static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
66
+
67
+
68
+ private:
69
+
70
+ /// Kernel API parameters object
71
+ Params params_;
72
+
73
+ public:
74
+
75
+ /// Access the Params structure
76
+ Params const& params() const {
77
+ return params_;
78
+ }
79
+
80
+ /// Determines whether the GEMM can execute the given problem.
81
+ static Status
82
+ can_implement(Arguments const& args) {
83
+ return TransformKernel::can_implement(args);
84
+ }
85
+
86
+ /// Gets the workspace size
87
+ static size_t
88
+ get_workspace_size(Arguments const& args) {
89
+ size_t workspace_bytes = 0;
90
+ workspace_bytes += TransformKernel::get_workspace_size(args);
91
+
92
+ CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
93
+
94
+ return workspace_bytes;
95
+ }
96
+
97
+ /// Computes the grid shape
98
+ static dim3
99
+ get_grid_shape(Arguments const& args, void* workspace = nullptr) {
100
+ auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace);
101
+ return TransformKernel::get_grid_shape(tmp_params);
102
+ }
103
+
104
+ /// Computes the grid shape
105
+ static dim3
106
+ get_grid_shape(Params const& params) {
107
+ return TransformKernel::get_grid_shape(params);
108
+ }
109
+
110
+
111
+ /// Initializes GEMM state from arguments.
112
+ Status
113
+ initialize(
114
+ Arguments const& args,
115
+ void* workspace = nullptr,
116
+ cudaStream_t stream = nullptr,
117
+ CudaHostAdapter* cuda_adapter = nullptr) {
118
+
119
+ CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace "
120
+ << workspace << ", stream: " << (stream ? "non-null" : "null")
121
+ << ", EnableCudaHostAdapter: " << (kEnableCudaHostAdapter ? "True" : "false"));
122
+
123
+ // Initialize the workspace
124
+ Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter);
125
+ if (status != Status::kSuccess) {
126
+ return status;
127
+ }
128
+ // Initialize the Params structure
129
+ params_ = TransformKernel::to_underlying_arguments(args, workspace);
130
+ // Don't set the function attributes - require the CudaHostAdapter to set it.
131
+ if constexpr (kEnableCudaHostAdapter) {
132
+ CUTLASS_ASSERT(cuda_adapter);
133
+ return Status::kSuccess;
134
+ }
135
+ else {
136
+ //
137
+ // Account for dynamic smem capacity if needed
138
+ //
139
+ int smem_size = TransformKernel::SharedStorageSize;
140
+
141
+ CUTLASS_ASSERT(cuda_adapter == nullptr);
142
+
143
+ if (smem_size >= (48 << 10)) {
144
+ CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
145
+ cudaError_t result = cudaFuncSetAttribute(
146
+ device_kernel<TransformKernel>,
147
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
148
+ smem_size);
149
+ if (cudaSuccess != result) {
150
+ result = cudaGetLastError(); // to clear the error bit
151
+ CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
152
+ return Status::kErrorInternal;
153
+ }
154
+ }
155
+ }
156
+ return Status::kSuccess;
157
+ }
158
+
159
+ static Status
160
+ run(Params& params,
161
+ cudaStream_t stream = nullptr,
162
+ CudaHostAdapter *cuda_adapter = nullptr,
163
+ int32_t kernel_index = 0,
164
+ bool launch_with_pdl = false) {
165
+ CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()");
166
+ dim3 const block = TransformKernel::get_block_shape();
167
+ dim3 const grid = get_grid_shape(params);
168
+
169
+ // configure smem size and carveout
170
+ int smem_size = TransformKernel::SharedStorageSize;
171
+
172
+ Status launch_result{ Status::kSuccess };
173
+ // Use extended launch API only for mainloops that use it
174
+ if constexpr (TransformKernel::ArchTag::kMinComputeCapability >= 90) {
175
+ // Currently only support 1x1x1 for transform kernel.
176
+ dim3 const cluster = {1,1,1};
177
+ void* kernel_params[] = {&params};
178
+
179
+ if constexpr (kEnableCudaHostAdapter) {
180
+ //
181
+ // Use the cuda host adapter
182
+ //
183
+ CUTLASS_ASSERT(cuda_adapter);
184
+ if (cuda_adapter) {
185
+
186
+ if (launch_with_pdl) {
187
+ CUTLASS_TRACE_HOST(
188
+ "TransformUniversalAdapter::run() does not support launching with PDL and a custom cuda adapter.");
189
+ return Status::kErrorInternal;
190
+ }
191
+ launch_result = cuda_adapter->launch(grid,
192
+ cluster,
193
+ block,
194
+ smem_size,
195
+ stream,
196
+ kernel_params,
197
+ kernel_index);
198
+ CUTLASS_TRACE_HOST("Kernel Launch Result" << cutlassGetStatusString(launch_result));
199
+ }
200
+ else {
201
+ return Status::kErrorInternal;
202
+ }
203
+ }
204
+ else {
205
+ CUTLASS_ASSERT(cuda_adapter == nullptr);
206
+ void const* kernel = (void const*) device_kernel<TransformKernel>;
207
+ if constexpr (TransformKernel::ArchTag::kMinComputeCapability == 90) {
208
+ launch_result = ClusterLauncher::launch(
209
+ grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
210
+ }
211
+ }
212
+ }
213
+ else {
214
+ launch_result = Status::kSuccess;
215
+ cutlass::arch::synclog_setup();
216
+
217
+ if constexpr (kEnableCudaHostAdapter) {
218
+ CUTLASS_ASSERT(cuda_adapter);
219
+ if (cuda_adapter) {
220
+ void* kernel_params[] = {&params};
221
+
222
+ launch_result = cuda_adapter->launch(
223
+ grid, block, smem_size, stream, kernel_params, 0
224
+ );
225
+
226
+ }
227
+ else {
228
+ return Status::kErrorInternal;
229
+ }
230
+ }
231
+ else {
232
+ CUTLASS_ASSERT(cuda_adapter == nullptr);
233
+ cutlass::kernel_launch<TransformKernel>(grid, block, smem_size, stream, params, launch_with_pdl);
234
+ }
235
+ }
236
+
237
+ cudaError_t result = cudaGetLastError();
238
+ if (cudaSuccess == result && Status::kSuccess == launch_result) {
239
+ return Status::kSuccess;
240
+ }
241
+ else if (cudaSuccess != result) {
242
+ CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
243
+ }
244
+ else if (Status::kSuccess != launch_result) {
245
+ CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cutlassGetStatusString(launch_result));
246
+ }
247
+ return Status::kErrorInternal;
248
+ }
249
+
250
+ //
251
+ // Non-static launch overloads that first create and set the internal params struct of this kernel handle.
252
+ //
253
+
254
+ /// Launches the kernel after first constructing Params internal state from supplied arguments.
255
+ Status
256
+ run(
257
+ Arguments const& args,
258
+ void* workspace = nullptr,
259
+ cudaStream_t stream = nullptr,
260
+ CudaHostAdapter *cuda_adapter = nullptr,
261
+ int32_t kernel_index = 0,
262
+ bool launch_with_pdl = false
263
+ ) {
264
+ Status status = initialize(args, workspace, stream, cuda_adapter);
265
+
266
+ if (Status::kSuccess == status) {
267
+ status = run(params_, stream, cuda_adapter, kernel_index, launch_with_pdl);
268
+ }
269
+ return status;
270
+ }
271
+
272
+ /// Launches the kernel after first constructing Params internal state from supplied arguments.
273
+ Status
274
+ operator()(
275
+ Arguments const& args,
276
+ void* workspace = nullptr,
277
+ cudaStream_t stream = nullptr,
278
+ CudaHostAdapter *cuda_adapter = nullptr,
279
+ bool launch_with_pdl = false) {
280
+ return run(args, workspace, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
281
+ }
282
+
283
+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
284
+ Status
285
+ run(
286
+ cudaStream_t stream = nullptr,
287
+ CudaHostAdapter *cuda_adapter = nullptr,
288
+ bool launch_with_pdl = false) {
289
+ return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
290
+ }
291
+
292
+ /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
293
+ Status
294
+ operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) {
295
+ return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
296
+ }
297
+ };
298
+
299
+ ////////////////////////////////////////////////////////////////////////////////
300
+
301
+ } // namespace cutlass::transform::device
302
+
303
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /* \file
33
+ \brief Convolution filter format transformation kernel.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <algorithm>
39
+ #include <random>
40
+
41
+ #include "cutlass/coord.h"
42
+ #include "cutlass/arch/arch.h"
43
+ #include "cutlass/layout/matrix.h"
44
+ #include "cutlass/cuda_host_adapter.hpp"
45
+
46
+ #include "cute/int_tuple.hpp"
47
+ #include "cute/tensor.hpp"
48
+ #include "cute/config.hpp"
49
+
50
+ namespace cutlass::transform::kernel {
51
+
52
+ using namespace cute;
53
+
54
+ enum class FilterFormat {
55
+ CKTRS,
56
+ CTRSK,
57
+ KTRSC
58
+ };
59
+
60
+ template <
61
+ FilterFormat SrcFormat,
62
+ FilterFormat DstFormat,
63
+ int NumDimensions,
64
+ class Element_,
65
+ int AlignmentBytes = 16
66
+ >
67
+ struct ConvFilterFormatTransformer {
68
+
69
+ using Element = Element_;
70
+ static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported");
71
+ static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported");
72
+ static_assert(AlignmentBytes > 0 && AlignmentBytes % static_cast<int>(sizeof(Element)) == 0, "Invalid alignment setting");
73
+
74
+ // In ktrsc order.
75
+ using FilterExtent = array<int, NumDimensions>;
76
+
77
+ // Default cta tile shape: 32x32
78
+ static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<32>{});
79
+ // Default thread layout: (4, 32)
80
+ static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{}));
81
+
82
+ static constexpr uint32_t MaxThreadsPerBlock = 128;
83
+ static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
84
+
85
+ using ArchTag = arch::Sm90;
86
+
87
+ // Default ctor
88
+ CUTLASS_HOST_DEVICE
89
+ ConvFilterFormatTransformer() {}
90
+
91
+ struct Arguments {
92
+ const void *src_ptr;
93
+ void *dst_ptr;
94
+ FilterExtent filter_extent;
95
+ };
96
+
97
+ struct Params {
98
+ using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr<const Element>(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{}))));
99
+ using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr<Element>(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0)))));
100
+
101
+ TensorSrc src;
102
+ TensorDst dst;
103
+ };
104
+
105
+ struct SharedStorage {
106
+ /* empty, no smem needed */
107
+ };
108
+
109
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
110
+
111
+ static Status
112
+ can_implement(Arguments const& args) {
113
+ bool implementable = true;
114
+ // alignment rule
115
+ {
116
+ int contiguous_dim = DstFormat == FilterFormat::CTRSK ? args.filter_extent[0] : args.filter_extent[NumDimensions - 1];
117
+ int align_element = AlignmentBytes / static_cast<int>(sizeof(Element));
118
+
119
+ implementable &= (contiguous_dim % align_element == 0);
120
+
121
+ if (!implementable) {
122
+ CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Alignment setting is invalid.\n");
123
+ return Status::kInvalid;
124
+ }
125
+ }
126
+
127
+ return Status::kSuccess;
128
+ }
129
+
130
+ static size_t
131
+ get_workspace_size(Arguments const& args) {
132
+ return 0;
133
+ }
134
+
135
+ static dim3
136
+ get_block_shape() {
137
+ return dim3(size(shape(ThreadLayout)), 1, 1);
138
+ }
139
+
140
+ static dim3
141
+ get_grid_shape(Params const& params) {
142
+ auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape));
143
+ auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape));
144
+
145
+ return dim3(dim_m, dim_n, 1);
146
+ }
147
+
148
+ static cutlass::Status
149
+ initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
150
+ CudaHostAdapter *cuda_adapter = nullptr) {
151
+ return Status::kSuccess;
152
+ }
153
+
154
+ static Params
155
+ to_underlying_arguments(Arguments const& args, void* workspace) {
156
+ auto k = args.filter_extent[0];
157
+ auto c = args.filter_extent[NumDimensions - 1];
158
+ auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent));
159
+
160
+ // source shape (s,r,t,k,c)
161
+ auto shape_src = flatten(make_shape(srt, k, c));
162
+ auto shape_dst = DstFormat == FilterFormat::CTRSK ? make_shape(k, c * product(srt)) : make_shape(c, k * product(srt));
163
+
164
+ auto src = make_tensor(make_gmem_ptr(recast_ptr<const Element>(args.src_ptr)), make_layout(shape_src));
165
+ auto dst = make_tensor(make_gmem_ptr(recast_ptr<Element>(args.dst_ptr)), make_layout(shape_dst));
166
+
167
+ return Params{src, dst};
168
+ }
169
+
170
+ CUTLASS_DEVICE
171
+ void operator()(Params const& params, char *smem_buf) {
172
+ // Tile the input tensor into blocks
173
+ auto block_coord = make_coord(blockIdx.x, blockIdx.y);
174
+ auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<32>{});
175
+ // Default thread layout: (4, 32)
176
+ auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{}));
177
+ auto vec_layout = make_layout(make_shape(Int<AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<1>{}));
178
+
179
+ Tensor tile_D = local_tile(params.dst, block_shape, block_coord);
180
+
181
+ // Construct tiled copy
182
+ using AccessType = cutlass::AlignedArray<Element, size(vec_layout)>;
183
+ using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
184
+
185
+ auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout);
186
+ auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
187
+ Tensor thr_tile_D = thr_copy.partition_D(tile_D);
188
+
189
+ // shape (s, r, t)
190
+ auto shape_trs = take<0, NumDimensions - 2>(shape(params.src));
191
+ // strided_c = c for format CTRSK, strided_c = k for format KTRSC
192
+ auto strided_c = DstFormat == FilterFormat::CTRSK ? get<NumDimensions - 1>(shape(params.src)) : get<NumDimensions - 2>(shape(params.src));
193
+ // shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC
194
+ auto shape_ctrs = append<NumDimensions - 1>(shape_trs, strided_c);
195
+ auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs);
196
+ // index of k for format CTRSK and index of c for format KTRSC
197
+ auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout)));
198
+ int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout)));
199
+
200
+ // Fragment to load from S and store to D
201
+ auto frag = make_fragment_like(thr_tile_D);
202
+ // Predicate tensor.
203
+ Tensor thr_tile_P = make_tensor<bool>(shape(thr_tile_D));
204
+
205
+ CUTLASS_PRAGMA_UNROLL
206
+ for (int i = 0; i < size(frag); ++i) {
207
+ auto srt_coord = take<0, NumDimensions - 2>(srtc_coord);
208
+ auto kc_coord = DstFormat == FilterFormat::CTRSK ?
209
+ make_coord(n_idx+i, get<NumDimensions - 2>(srtc_coord)) :
210
+ make_coord(get<NumDimensions - 2>(srtc_coord), n_idx+i);
211
+ auto coord = flatten(make_coord(srt_coord, kc_coord));
212
+ thr_tile_P(i) = elem_less(coord, shape(params.src));
213
+ if (thr_tile_P(i)) {
214
+ frag(i) = params.src(coord);
215
+ }
216
+ }
217
+
218
+ // Copy from RMEM to GMEM
219
+ copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
220
+ }
221
+ };
222
+
223
+ } // namespace cutlass::transform::kernel
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Compress utils specific for SM90 structure sparse kernels
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cute/container/bit_field.hpp" // cute::bit_field
39
+ #include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v, cute::uint_bit_t
40
+ #include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
41
+ #include "cute/algorithm/cooperative_copy.hpp" // cute::cooperative_copy
42
+ #include "cutlass/arch/arch.h" // cutlass::arch::Sm90
43
+ #include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
44
+ #include "cutlass/cutlass.h" // cutlass::Status
45
+ #include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
46
+ #include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
47
+ #include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
48
+ #include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
49
+ #include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v
50
+ #include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
51
+
52
+ namespace cutlass::transform::kernel {
53
+
54
+ using namespace cute;
55
+
56
+ template<
57
+ class ProblemShape_,
58
+ class ElementA_,
59
+ class LayoutATag_,
60
+ class SparseConfig_
61
+ >
62
+ class SM90StructuredSparseCompressor {
63
+ public:
64
+ using SparseConfig = SparseConfig_;
65
+ using ProblemShape = ProblemShape_;
66
+
67
+ // * EltA
68
+ using ElementA = ElementA_;
69
+ using ElementAUint = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
70
+ using ElementAMma = typename SparseConfig::ElementAMma;
71
+ using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
72
+ using ElementAMmaRawUnit = cute::uint_bit_t<cute::sizeof_bits_v<ElementAMmaRaw>>;
73
+ using ElementASparsity = typename SparseConfig::ElementASparsity;
74
+ using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
75
+ using ElementAUintCompressed = cute::sparse_elem<ElementASparsity{}, ElementAUint>;
76
+ using LayoutATag = LayoutATag_;
77
+ using LayoutA = LayoutATag;
78
+ using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;
79
+
80
+ // * EltE
81
+ using ElementEMma = typename SparseConfig::ElementEMma;
82
+ using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
83
+ using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;
84
+ // Data Type for storing one chunk's metadata
85
+ static constexpr int ElementEBitsPerChunk = typename SparseConfig::ElementEBitsPerChunk{};
86
+ CUTE_STATIC_ASSERT(ElementEBitsPerChunk == 4, "ElementEBitsPerChunk is 4 for SM90");
87
+ using ElementEChunk = cute::uint_bit_t<ElementEBitsPerChunk>;
88
+ CUTE_STATIC_ASSERT(cute::is_same_v<ElementEChunk, cute::uint4_t>, "ElementEChunk is uint4_t for SM90");
89
+ using ElementESparsityPerChunk = Int<ElementEMmaSparsity{} / (cute::sizeof_bits_v<ElementEMmaRaw> / ElementEBitsPerChunk)>;
90
+
91
+ // AtomE
92
+ using TensorEAtom = typename SparseConfig::TensorEAtom;
93
+ using TensorEAtomK = typename SparseConfig::TensorEAtomK;
94
+ using TensorEAtomM = typename SparseConfig::TensorEAtomM;
95
+
96
+ static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
97
+ static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
98
+ static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
99
+ static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
100
+ static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
101
+
102
+ // * Alignment
103
+ static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
104
+ static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
105
+ static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
106
+ static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};
107
+
108
+ // Required by `device_kernel`
109
+ static constexpr int MaxThreadsPerBlock = TensorEAtomM{};
110
+ static constexpr int MinBlocksPerMultiprocessor = 1;
111
+ using ArchTag = arch::Sm90;
112
+
113
+ struct SharedStorage {
114
+ ElementEMma cEsE[cute::size(TensorEAtom{})];
115
+ ElementAUintCompressed cACsAC[cute::size(TensorEAtom{})];
116
+ ElementAUint cAsA[cute::size(TensorEAtom{})];
117
+ };
118
+
119
+ static constexpr int SharedStorageSize = sizeof(SharedStorage);
120
+
121
+ struct TransformArguments {
122
+ void const* ptr_A{nullptr};
123
+ StrideA dA{};
124
+ void* ptr_ACompress{nullptr};
125
+ void* ptr_E{nullptr};
126
+ };
127
+
128
+ using TransformParams = TransformArguments;
129
+
130
+ struct Arguments {
131
+ ProblemShape problem_shape{};
132
+ TransformArguments transform{};
133
+ KernelHardwareInfo hw_info{};
134
+ };
135
+
136
+ struct Params {
137
+ ProblemShape problem_shape{};
138
+ TransformParams transform{};
139
+ KernelHardwareInfo hw_info{};
140
+ void* workspace = nullptr;
141
+ };
142
+
143
+ public:
144
+ static Params
145
+ to_underlying_arguments(Arguments const& args, void* workspace = nullptr) {
146
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::to_underlying_arguments()");
147
+ return Params{{args.problem_shape},
148
+ {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E},
149
+ {args.hw_info},
150
+ workspace};
151
+ }
152
+
153
+ static Status
154
+ can_implement(Arguments const& args) {
155
+ auto [M, N, K, L] = args.problem_shape;
156
+ if (K % LogicalElemsAPerChunk != 0) {
157
+ CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size");
158
+ return Status::kErrorInvalidProblem;
159
+ }
160
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::can_implement() (True)");
161
+ return Status::kSuccess;
162
+ }
163
+
164
+ static size_t
165
+ get_workspace_size(Arguments const& args) {
166
+ CUTLASS_UNUSED(args);
167
+ // Backward compatible with host compressor
168
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_workspace_size() (" << SharedStorageSize << ")");
169
+ return SharedStorageSize;
170
+ }
171
+
172
+ static Status
173
+ initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
174
+ CudaHostAdapter *cuda_adapter = nullptr) {
175
+ CUTLASS_UNUSED(args);
176
+ CUTLASS_UNUSED(workspace);
177
+ CUTLASS_UNUSED(stream);
178
+ CUTLASS_UNUSED(cuda_adapter);
179
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::initialize_workspace()");
180
+ return Status::kSuccess;
181
+ }
182
+
183
+ static dim3
184
+ get_grid_shape(Params const& params) {
185
+ constexpr int MaxAlignmentM = cutlass::const_max(TensorEAlignmentM, TensorAAlignmentM);
186
+ constexpr int MaxAlignmentK = cutlass::const_max(TensorEAlignmentK, TensorAAlignmentK);
187
+ const auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape;
188
+
189
+ const int GemmMAlignedMax = cutlass::round_up(GemmM, MaxAlignmentM);
190
+ const int GemmKAlignedMax = cutlass::round_up(GemmK, MaxAlignmentK);
191
+
192
+ const int gridDim_X = cutlass::ceil_div(GemmMAlignedMax, TensorEAtomM{});
193
+ const int gridDim_Y = cutlass::ceil_div(GemmKAlignedMax, TensorEAtomK{});
194
+ const int gridDim_Z = GemmL;
195
+
196
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_grid_shape() ("
197
+ << gridDim_X << ", "
198
+ << gridDim_Y << ", "
199
+ << gridDim_Z << ")");
200
+ return dim3(gridDim_X, gridDim_Y, gridDim_Z);
201
+ }
202
+
203
+ static dim3
204
+ get_block_shape() {
205
+ CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_block_shape() ("
206
+ << MaxThreadsPerBlock << ", "
207
+ << 1 << ", "
208
+ << 1 << ")");
209
+ return dim3(MaxThreadsPerBlock, 1, 1);
210
+ }
211
+
212
+ CUTE_DEVICE
213
+ void
214
+ operator()(Params params, void* smem_buf = nullptr) {
215
+ run(params, smem_buf);
216
+ }
217
+
218
+ CUTE_DEVICE
219
+ static void
220
+ run(Params params, void* smem_buf = nullptr) {
221
+ structure_sparse_compress(params, smem_buf);
222
+ }
223
+
224
+ private:
225
+
226
+ struct MetadataOneChunk1to2 {
227
+
228
+ CUTE_DEVICE
229
+ void set_metadata_bits(int elt_log_idx, int elt_phy_idx) {
230
+ auto metadata_bits = [&]() -> uint8_t {
231
+ CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 2);
232
+ switch (elt_log_idx) {
233
+ case 0:
234
+ return 0b0100;
235
+ case 1:
236
+ return 0b1110;
237
+ default:
238
+ CUTE_GCC_UNREACHABLE;
239
+ }
240
+ };
241
+
242
+ storage_ |= (metadata_bits() << (4 * elt_phy_idx));
243
+ }
244
+
245
+
246
+ CUTE_DEVICE
247
+ ElementEChunk storage() const {
248
+ return ElementEChunk{storage_};
249
+ }
250
+
251
+ private:
252
+ uint8_t storage_ = 0b0000;
253
+ };
254
+
255
+ struct MetadataOneChunk2to4{
256
+
257
+ CUTE_DEVICE
258
+ void set_metadata_bits(int elt_log_idx, int elt_phy_idx) {
259
+ auto metadata_bits = [&]() -> uint8_t {
260
+ CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 4);
261
+ switch (elt_log_idx) {
262
+ case 0:
263
+ return 0b00;
264
+ case 1:
265
+ return 0b01;
266
+ case 2:
267
+ return 0b10;
268
+ case 3:
269
+ return 0b11;
270
+ default:
271
+ CUTLASS_ASSERT(false);
272
+ CUTE_GCC_UNREACHABLE;
273
+ return 0b00;
274
+ }
275
+ };
276
+
277
+ storage_ |= (metadata_bits() << (2 * elt_phy_idx));
278
+ }
279
+
280
+ CUTE_DEVICE
281
+ ElementEChunk storage() const {
282
+ return ElementEChunk{storage_};
283
+ }
284
+
285
+ private:
286
+ uint8_t storage_ = 0b0000;
287
+ };
288
+
289
+ using MetadataOneChunk = cute::conditional_t<SparseConfig::IsTF32,
290
+ MetadataOneChunk1to2,
291
+ MetadataOneChunk2to4>;
292
+
293
+ private:
294
+
295
+ CUTE_DEVICE
296
+ static void
297
+ structure_sparse_compress(Params params, void* smem_buf) {
298
+ // * Input Params
299
+ auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape;
300
+ auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform;
301
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
302
+
303
+ [[maybe_unused]] const int gridDim_X = gridDim.x;
304
+ [[maybe_unused]] const int gridDim_Y = gridDim.y;
305
+ [[maybe_unused]] const int gridDim_Z = gridDim.z;
306
+ [[maybe_unused]] const int blockDim_X = blockDim.x;
307
+
308
+ // * Global Tensor Layout
309
+ const cute::Layout layout_gA = make_layout(make_shape(GemmM, GemmK, GemmL), dA);
310
+ const cute::Layout layout_gAC = SparseConfig::fill_layoutA(params.problem_shape);
311
+ const cute::Layout layout_gE = SparseConfig::fill_layoutE(params.problem_shape);
312
+
313
+ // * Construct Global Tensor
314
+ const cute::Tensor gA = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementAUint>(ptr_A)), layout_gA);
315
+ cute::Tensor gAC_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementAUintCompressed>(ptr_ACompress)), layout_gAC );
316
+ cute::Tensor gAC = cute::recast<ElementAUint>(gAC_sparse);
317
+ cute::Tensor gE_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementEMma>(ptr_E)), layout_gE);
318
+ cute::Tensor gE = cute::recast<ElementEMmaRaw>(gE_sparse);
319
+
320
+ // * CTA Tensor Layout
321
+ using cAsA_layout_row = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutRight{}));
322
+ using cAsA_layout_col = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutLeft{}));
323
+ using cAsA_layout = cute::conditional_t<cute::is_same_v<LayoutATag, layout::RowMajor>, cAsA_layout_row, cAsA_layout_col>;
324
+ using cACsAC_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementASparsity{}), LayoutRight{}));
325
+ using cEsE_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementEMmaSparsity{}), LayoutRight{}));
326
+
327
+ CUTE_STATIC_ASSERT(cute::is_static_v<TensorEAtom>, "TensorEAtom needs to be static");
328
+ CUTE_STATIC_ASSERT(cute::is_static_v<cAsA_layout>, "cAsA_layout needs to be static");
329
+ CUTE_STATIC_ASSERT(cute::is_static_v<cACsAC_layout>, "cACsAC_layout needs to be static");
330
+ CUTE_STATIC_ASSERT(cute::is_static_v<cEsE_layout>, "cEsE_layout needs to be static");
331
+
332
+ const int blockIdx_X = blockIdx.x;
333
+ const int blockIdx_Y = blockIdx.y;
334
+ const int blockIdx_Z = blockIdx.z;
335
+ const int threadIdx_X = threadIdx.x;
336
+
337
+ // * Construct CTA Tensor
338
+ const auto cta_coord = make_coord(blockIdx_X, blockIdx_Y, blockIdx_Z);
339
+ cute::Tensor cAgA = cute::recast<ElementAMmaRawUnit>(local_tile(gA, shape(cAsA_layout{}), cta_coord));
340
+ cute::Tensor cACgAC = cute::recast<ElementAMmaRawUnit>(local_tile(gAC, shape(cACsAC_layout{}), cta_coord));
341
+ cute::Tensor cEgE = local_tile(gE, shape(cEsE_layout{}), cta_coord);
342
+
343
+ cute::Tensor cAsA = cute::recast<ElementAMmaRawUnit>(make_tensor(make_smem_ptr(cute::recast_ptr<ElementAUint>(shared_storage.cAsA)), cAsA_layout{}));
344
+ cute::Tensor cACsAC = cute::recast<ElementAMmaRawUnit>(make_tensor(make_smem_ptr(cute::recast_ptr<ElementAUint>(shared_storage.cACsAC)), cACsAC_layout{}));
345
+ cute::Tensor cEsE = make_tensor(make_smem_ptr(cute::recast_ptr<ElementEMmaRaw>(shared_storage.cEsE)), cEsE_layout{});
346
+ cute::Tensor cEsE_chunk = cute::recast<ElementEChunk>(cEsE);
347
+
348
+ // * Handle in unit of Chunk when compress
349
+ using OneChunkSizeA = Int<LogicalElemsAMmaRawPerChunk>;
350
+ using OneChunkSizeAC = Int<PhysicalElemsAMmaRawPerChunk>;
351
+ using OneChunkSizeE = Int<LogicalElemsAPerChunk / ElementESparsityPerChunk{}>;
352
+ using NumOneChunkK = Int<cutlass::ceil_div(TensorEAtomK{}, LogicalElemsAPerChunk)>;
353
+
354
+ cute::Tensor cAsA_log_chunk = logical_divide(cAsA, make_shape(_, OneChunkSizeA{}));
355
+ cute::Tensor cACsAC_log_chunk = logical_divide(cACsAC, make_shape(_, OneChunkSizeAC{}));
356
+ cute::Tensor cEsE_log_chunk = logical_divide(cEsE_chunk, make_shape(_, OneChunkSizeE{}));
357
+
358
+ // * Corner Case Handle
359
+ const auto GemmM_within_Cta = (GemmM - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmM - blockIdx_X * TensorEAtomM{};
360
+ const auto GemmK_within_Cta = ( (GemmK - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmK - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw;
361
+ const auto GemmK_NumOneChunk_within_Cta = GemmK_within_Cta / LogicalElemsAMmaRawPerChunk;
362
+
363
+ const auto GemmMAlignedAC = cutlass::round_up(GemmM, TensorAAlignmentM);
364
+ const auto GemmKAlignedAC = cutlass::round_up(GemmK, TensorAAlignmentK);
365
+ const auto GemmMAlignedAC_within_Cta = (GemmMAlignedAC - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmMAlignedAC - blockIdx_X * TensorEAtomM{};
366
+ const auto GemmKAlignedAC_within_Cta = ( (GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw;
367
+
368
+ // * Clear CTA Smem Tensor
369
+ cooperative_clear<MaxThreadsPerBlock>(threadIdx_X, cACsAC);
370
+ cooperative_clear<MaxThreadsPerBlock>(threadIdx_X, cEsE);
371
+
372
+ // * Input CTA Tensor G to S
373
+ if (GemmM_within_Cta == TensorEAtomM{} && GemmK_within_Cta == TensorEAtomK{}) {
374
+ copy_vec_pred<false, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
375
+ }
376
+ else {
377
+ copy_vec_pred<true, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
378
+ }
379
+
380
+ // Construct a sign bit mask for handling negative zeros
381
+ ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 };
382
+ if constexpr (has_negative_zero_v<ElementA>) {
383
+ ElementAMmaRawUnit one_sign_mask = static_cast<ElementAMmaRawUnit>(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v<ElementA> - 1)));
384
+ for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) {
385
+ sign_mask = static_cast<ElementAMmaRawUnit>((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v<ElementA>));
386
+ }
387
+ }
388
+
389
+ // * Compress
390
+ // cACsAC is always row major order
391
+ // TensorEAtomM threads perform the compression, each thread compress one row
392
+ const int row_i = threadIdx_X;
393
+ if (row_i < GemmM_within_Cta) {
394
+
395
+ CUTE_UNROLL
396
+ for (int col_chunk_i = 0; col_chunk_i < NumOneChunkK{}; ++col_chunk_i) {
397
+ if (col_chunk_i < GemmK_NumOneChunk_within_Cta) {
398
+ // Compress is handled in unit of ElementAMmaRawUnit
399
+ cute::Tensor tAsA = cAsA_log_chunk(row_i, make_coord(_, col_chunk_i));
400
+ cute::Tensor tACsAC = cACsAC_log_chunk(row_i, make_coord(_, col_chunk_i));
401
+ cute::Tensor tEsE = cEsE_log_chunk(row_i, make_coord(_, col_chunk_i));
402
+
403
+ int non_zero_cnt = 0;
404
+ // None zero element indx
405
+ // e.g.
406
+ // 2:4 sparsity [x 0 0 x]
407
+ // non_zero_elt_log_idx = [0, 3]
408
+ int non_zero_elt_log_idx[OneChunkSizeAC{}] = { 0 };
409
+
410
+ // * Find None Zero Element Idx within Chunk
411
+ CUTE_UNROLL
412
+ for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) {
413
+ ElementAMmaRawUnit elem_A = tAsA[elt_log_idx];
414
+
415
+ // Handle negative 0
416
+ ElementAMmaRawUnit masked_elem_A = elem_A;
417
+ if constexpr (has_negative_zero_v<ElementA>) {
418
+ masked_elem_A = elem_A & sign_mask;
419
+ }
420
+
421
+ if (masked_elem_A != ElementAMmaRawUnit{0}) {
422
+ non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx;
423
+ tACsAC[non_zero_cnt] = elem_A;
424
+ non_zero_cnt++;
425
+ }
426
+ }
427
+
428
+ // * Corner Case for 2:4 sparsity
429
+ if constexpr (cute::sizeof_bits_v<ElementAMmaRawUnit> < 32) {
430
+ // i.e. [0 0 0 x] -> [(0) 0 0 x]
431
+ if (non_zero_cnt == 1 && non_zero_elt_log_idx[0] == 3) {
432
+ tACsAC[1] = tACsAC[0];
433
+ tACsAC[0] = ElementAMmaRawUnit{0};
434
+ non_zero_elt_log_idx[0] = 0;
435
+ non_zero_elt_log_idx[1] = 3;
436
+ }
437
+ // i.e. [0 0 x 0] -> [0 0 x (0)]
438
+ // i.e. [0 x 0 0] -> [0 x 0 (0)]
439
+ // i.e. [x 0 0 0] -> [x 0 0 (0)]
440
+ else if (non_zero_cnt == 1) {
441
+ tACsAC[1] = ElementAMmaRawUnit{0};
442
+ non_zero_elt_log_idx[1] = 3;
443
+ }
444
+ }
445
+
446
+ // * Set Metadata Bits
447
+ MetadataOneChunk metadata_one_chunk;
448
+ CUTE_UNROLL
449
+ for (int elt_phy_idx = 0; elt_phy_idx < OneChunkSizeAC{}; elt_phy_idx++) {
450
+ metadata_one_chunk.set_metadata_bits(non_zero_elt_log_idx[elt_phy_idx], elt_phy_idx);
451
+ }
452
+ tEsE[0] = metadata_one_chunk.storage();
453
+
454
+ }
455
+ else {
456
+ break;
457
+ }
458
+ }
459
+ }
460
+
461
+ // * Sync after Compress
462
+ __syncthreads();
463
+
464
+ // * Output Cta Tensor S to G
465
+ if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) {
466
+ constexpr int MaxVecBits = 128; // STG.128
467
+ cute::cooperative_copy<MaxThreadsPerBlock, MaxVecBits>(threadIdx_X, cEsE, cEgE);
468
+ }
469
+
470
+ if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) {
471
+ copy_vec_pred<false, LayoutATag>(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value));
472
+ }
473
+ else {
474
+ copy_vec_pred<true, LayoutATag>(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value));
475
+ }
476
+
477
+ } // end of structure_sparse_compress()
478
+
479
+ template<uint32_t NumThreads,
480
+ typename TensorSrc>
481
+ CUTE_DEVICE
482
+ static void
483
+ cooperative_clear(
484
+ uint32_t const& tid,
485
+ TensorSrc dSrc) {
486
+
487
+ auto dSrctSrc = local_partition(dSrc, make_layout(make_shape(NumThreads, _1{})), tid);
488
+ cute::clear(dSrctSrc);
489
+
490
+ // Sync all thread data access
491
+ __syncthreads();
492
+ }
493
+
494
+ template <bool pred,
495
+ typename LayoutTag,
496
+ typename TensorSrc,
497
+ typename TensorDst>
498
+ CUTE_DEVICE
499
+ static void
500
+ copy_vec_pred(
501
+ TensorSrc dSrc,
502
+ TensorDst dDst,
503
+ int threadIdx_X,
504
+ int valid_rows,
505
+ int valid_cols) {
506
+
507
+ constexpr bool IsRowMajor = cute::is_same_v<LayoutTag, cutlass::layout::RowMajor>;
508
+ using Element = typename TensorSrc::element_type;
509
+ constexpr bool IsQmmaF6 = cute::sizeof_bits_v<Element> == 6;
510
+
511
+ CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dSrc))>, "shape(dSrc) needs to be static");
512
+ CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dDst))>, "shape(dDst) needs to be static");
513
+ CUTE_STATIC_ASSERT(cute::sizeof_bits_v<typename TensorSrc::element_type> == cute::sizeof_bits_v<typename TensorDst::element_type>,
514
+ "dSrc and dDst need to have same element bit width");
515
+ CUTE_STATIC_ASSERT(cute::size(dSrc) == cute::size(dDst), "dSrc and dDst need to have same size");
516
+
517
+ // ValueShape
518
+ using ValueShape =
519
+ cute::conditional_t<IsQmmaF6,
520
+ Shape<Int<1>, Int<1>>,
521
+ cute::conditional_t<IsRowMajor,
522
+ Shape<Int<1>, Int<128 / sizeof_bits_v<Element>>>,
523
+ Shape<Int<128 / sizeof_bits_v<Element>>, Int<1>>>
524
+ >;
525
+
526
+ constexpr int ValueShapeRows = shape<0>(ValueShape{});
527
+ constexpr int ValueShapeCols = shape<1>(ValueShape{});
528
+
529
+ // ThreadShape
530
+ using ThreadShape =
531
+ cute::conditional_t<IsQmmaF6,
532
+ cute::conditional_t<IsRowMajor,
533
+ Shape<Int<MaxThreadsPerBlock>, Int<1>>,
534
+ Shape<Int<1>, Int<MaxThreadsPerBlock>>>,
535
+ cute::conditional_t<IsRowMajor,
536
+ Shape<Int<MaxThreadsPerBlock / (shape<1>(dSrc) / ValueShapeCols)>, Int< (shape<1>(dSrc) / ValueShapeCols)>>,
537
+ Shape<Int< (shape<0>(dSrc) / ValueShapeRows)>, Int<MaxThreadsPerBlock / (shape<0>(dSrc) / ValueShapeRows)>>>
538
+ >;
539
+
540
+ constexpr int ThreadShapeRows = shape<0>(ThreadShape{});
541
+ constexpr int ThreadShapeCols = shape<1>(ThreadShape{});
542
+
543
+ const int threadIdx_X_row = threadIdx_X / ThreadShapeCols;
544
+ const int threadIdx_X_col = threadIdx_X % ThreadShapeCols;
545
+
546
+ // Row Major
547
+ if constexpr (IsRowMajor) {
548
+ CUTE_UNROLL
549
+ for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) {
550
+ CUTE_UNROLL
551
+ for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) {
552
+ CUTE_UNROLL
553
+ for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) {
554
+ CUTE_UNROLL
555
+ for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
556
+ const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
557
+ const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
558
+ if constexpr ( (not pred) and (not IsQmmaF6) ) {
559
+ dDst(row_i, col_i) = dSrc(row_i, col_i);
560
+ }
561
+ else {
562
+ if (row_i < valid_rows && col_i < valid_cols) {
563
+ dDst(row_i, col_i) = dSrc(row_i, col_i);
564
+ }
565
+ }
566
+ }
567
+ }
568
+ }
569
+ }
570
+ }
571
+ // Col Major
572
+ else {
573
+ CUTE_UNROLL
574
+ for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) {
575
+ CUTE_UNROLL
576
+ for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) {
577
+ CUTE_UNROLL
578
+ for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
579
+ CUTE_UNROLL
580
+ for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) {
581
+ const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
582
+ const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
583
+ if constexpr ( (not pred) and (not IsQmmaF6) ) {
584
+ dDst(row_i, col_i) = dSrc(row_i, col_i);
585
+ }
586
+ else {
587
+ if (row_i < valid_rows && col_i < valid_cols) {
588
+ dDst(row_i, col_i) = dSrc(row_i, col_i);
589
+ }
590
+ }
591
+ }
592
+ }
593
+ }
594
+ }
595
+ }
596
+
597
+ // Sync all thread data access
598
+ __syncthreads();
599
+ } // end of copy_vec_pred()
600
+
601
+ };
602
+
603
+ } // namespace cutlass::transform::kernel
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Compress utils for structured sparse kernels
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <algorithm> // std::fill
39
+ #include <array> // std::array
40
+ #include <random> // std::mt19937
41
+
42
+ #include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
43
+ #include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
44
+ #include "cutlass/arch/arch.h" // cutlass::arch::SmXY
45
+ #include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false
46
+ #include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
47
+ #include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
48
+ #include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
49
+
50
+ #include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp"
51
+
52
+ namespace cutlass::transform::kernel {
53
+
54
+ template<
55
+ class ProblemShape_,
56
+ class ElementA_,
57
+ class LayoutATag_,
58
+ class SparseConfig_
59
+ >
60
+ class StructuredSparseCompressorUtility {
61
+ public:
62
+ using SparseConfig = SparseConfig_;
63
+ using ProblemShape = ProblemShape_;
64
+
65
+ //* EltA
66
+ using ElementA = ElementA_;
67
+ using LayoutATag = LayoutATag_;
68
+ using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;
69
+ using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
70
+ using ElementASparsity = typename SparseConfig::ElementASparsity;
71
+ using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
72
+
73
+ //* EltE
74
+ using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
75
+ using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;
76
+
77
+ //* AtomE
78
+ using TensorEAtom = typename SparseConfig::TensorEAtom;
79
+ using TensorEAtomK = typename SparseConfig::TensorEAtomK;
80
+ using TensorEAtomM = typename SparseConfig::TensorEAtomM;
81
+
82
+ static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
83
+ static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
84
+ static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
85
+ static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
86
+ static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
87
+
88
+ //* Alignment
89
+ static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
90
+ static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
91
+ static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
92
+ static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};
93
+
94
+ StructuredSparseCompressorUtility() = default;
95
+
96
+ StructuredSparseCompressorUtility(ProblemShape problem, StrideA dA) {
97
+ set_problem_size(problem, dA);
98
+ }
99
+
100
+ void set_problem_size(ProblemShape problem, StrideA dA_) {
101
+ M = cute::size<0>(problem);
102
+ K = cute::size<2>(problem);
103
+ L = cute::size<3>(problem);
104
+
105
+ // The following three vars are logical elem count!
106
+ K_alignedA = round_up(K, TensorAAlignmentK);
107
+ M_alignedA = round_up(M, TensorAAlignmentM);
108
+ K_alignedE = round_up(K, TensorEAlignmentK);
109
+ M_alignedE = round_up(M, TensorEAlignmentM);
110
+
111
+ dA = dA_;
112
+ }
113
+
114
+ /**
115
+ * @brief Get the TensorE number of ElementE along K after alignment requirement
116
+ *
117
+ * @return int : number of ElementE (uint8_t) along K-dim
118
+ */
119
+ int get_metadata_m_physical() const {
120
+ return M_alignedE;
121
+ }
122
+
123
+ /**
124
+ * @brief Get the TensorE number of ElementE along M after alignment requirement
125
+ *
126
+ * @return int : number of ElementE (uint8_t) along M-dim
127
+ */
128
+ int get_metadata_k_physical() const {
129
+ return K_alignedE / ElementEMmaSparsity{};
130
+ }
131
+
132
+ /**
133
+ * @brief Get the TensorACompressed number of ElementA along K after alignment requirement
134
+ *
135
+ * @return int : number of ElementA along K-dim
136
+ */
137
+ int get_tensorA_k_physical() const {
138
+ return K_alignedA / ElementASparsity{};
139
+ }
140
+
141
+ /**
142
+ * @brief Get the TensorACompressed number of ElementA along M after alignment requirement
143
+ *
144
+ * @return int : number of ElementA along M-dim
145
+ */
146
+ int get_tensorA_m_physical() const {
147
+ return M_alignedA;
148
+ }
149
+
150
+ /**
151
+ * @brief Get the TensorACompressed Bytes
152
+ *
153
+ * @return uint64_t bytes
154
+ */
155
+ uint64_t get_compressed_tensor_A_bytes() const {
156
+ const auto tensor_a_comp_num_elt_a = get_tensorA_m_physical() * get_tensorA_k_physical() * L;
157
+ const auto tensor_a_comp_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_a_comp_num_elt_a * cute::sizeof_bits_v<ElementA>);
158
+ return tensor_a_comp_bytes;
159
+ }
160
+
161
+ /**
162
+ * @brief Get the TensorA Bytes
163
+ *
164
+ * @return uint64_t bytes
165
+ */
166
+ uint64_t get_raw_tensor_A_bytes() const {
167
+ const auto tensor_a_num_elt_a = uint64_t(M) * uint64_t(K) * uint64_t(L);
168
+ const auto tensor_a_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_a_num_elt_a * cute::sizeof_bits_v<ElementA>);
169
+ return tensor_a_bytes;
170
+ }
171
+
172
+ /**
173
+ * @brief Get the TensorE Bytes
174
+ *
175
+ * @return uint64_t bytes
176
+ */
177
+ uint64_t get_tensor_E_bytes() const {
178
+ const auto tensor_e_num_elt_a = uint64_t(get_metadata_m_physical()) * uint64_t(get_metadata_k_physical()) * uint64_t(L);
179
+ const auto tensor_e_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_e_num_elt_a * cute::sizeof_bits_v<ElementEMmaRaw>);
180
+ return tensor_e_bytes;
181
+ }
182
+
183
+ constexpr auto fill_layoutA_from_compressor() const {
184
+ return SparseConfig::fill_layoutA(cute::make_tuple(M,_1{},K,L));
185
+ }
186
+
187
+ constexpr auto fill_layoutE_from_compressor() const {
188
+ return SparseConfig::fill_layoutE(cute::make_tuple(M,_1{},K,L));
189
+ }
190
+
191
+ void structure_sparse_zero_mask_fill(void* host_a_ptr, uint64_t seed) {
192
+
193
+ constexpr int ChunkSize = LogicalElemsAMmaRawPerChunk;
194
+ using ChunkElement = cute::uint_bit_t<cute::sizeof_bits_v<ElementAMmaRaw>>;
195
+
196
+ cute::Tensor gA_eltA = cute::make_tensor(
197
+ cute::recast_ptr<ElementA>(host_a_ptr),
198
+ cute::make_layout(make_shape(M, K, L), dA));
199
+
200
+ // Input TensorA is handled in unit of ElementAMmaRaw instead of ElementA
201
+ cute::Tensor gA = cute::recast<ChunkElement>(gA_eltA);
202
+
203
+ // Extract out the Chunk from K-mode
204
+ Tensor gA_chunk = cute::zipped_divide(gA, cute::Shape<_1,cute::Int<ChunkSize>>{}); // (Chunk, Rest)
205
+
206
+ // Half of the data is zero to indicate sparsityA = 2
207
+ std::array<int, ChunkSize> nnzb_indicator{};
208
+ for (size_t i = 1; i < nnzb_indicator.size(); i += 2) {
209
+ nnzb_indicator.at(i) = 1;
210
+ }
211
+
212
+ std::mt19937 rng(seed);
213
+ auto rest_shape = cute::shape<1>(gA_chunk);
214
+ for (auto iter = cute::make_coord_iterator(rest_shape); iter != cute::ForwardCoordIteratorSentinel{}; ++iter) {
215
+ std::shuffle(nnzb_indicator.begin(), nnzb_indicator.end(), rng);
216
+ for (int c = 0; c < size<0>(gA_chunk); ++c) { // for each elem within chunk
217
+ if (nnzb_indicator[c] == 0) {
218
+ gA_chunk(c, *iter) = ChunkElement{0};
219
+ }
220
+ } // end of within chunk
221
+ } // end of chunk_idx
222
+ }
223
+
224
+ int M{-1};
225
+ int K{-1};
226
+ int L{-1};
227
+ StrideA dA{};
228
+
229
+ private:
230
+ int K_alignedA{-1};
231
+ int M_alignedA{-1};
232
+ int K_alignedE{-1};
233
+ int M_alignedE{-1};
234
+ };
235
+
236
+ ////////////////////////////////////////////////////////////////////////////////
237
+
238
+ template<
239
+ class ProblemShape,
240
+ class ElementA,
241
+ class LayoutATag,
242
+ class SparseConfig,
243
+ class ArchTag
244
+ >
245
+ struct StructuredSparseCompressorSelector {
246
+ static_assert(cutlass::detail::dependent_false<ArchTag>,
247
+ "Could not select a structured sparse compressor for given parameters.");
248
+ };
249
+
250
+ template<
251
+ class ProblemShape,
252
+ class ElementA,
253
+ class LayoutATag,
254
+ class SparseConfig
255
+ >
256
+ struct StructuredSparseCompressorSelector<
257
+ ProblemShape,
258
+ ElementA,
259
+ LayoutATag,
260
+ SparseConfig,
261
+ arch::Sm90> {
262
+ using Compressor = SM90StructuredSparseCompressor<
263
+ ProblemShape,
264
+ ElementA,
265
+ LayoutATag,
266
+ SparseConfig
267
+ >;
268
+ };
269
+
270
+ template<
271
+ class ProblemShape,
272
+ class ElementA,
273
+ class LayoutATag,
274
+ class SparseConfig
275
+ >
276
+ struct StructuredSparseCompressorSelector<
277
+ ProblemShape,
278
+ ElementA,
279
+ LayoutATag,
280
+ SparseConfig,
281
+ arch::Sm100> {
282
+ using Compressor = SM90StructuredSparseCompressor<
283
+ ProblemShape,
284
+ ElementA,
285
+ LayoutATag,
286
+ SparseConfig
287
+ >;
288
+ };
289
+
290
+ template<
291
+ class ProblemShape,
292
+ class ElementA,
293
+ class LayoutATag,
294
+ class SparseConfig
295
+ >
296
+ struct StructuredSparseCompressorSelector<
297
+ ProblemShape,
298
+ ElementA,
299
+ LayoutATag,
300
+ SparseConfig,
301
+ arch::Sm120> {
302
+ using Compressor = SM90StructuredSparseCompressor<
303
+ ProblemShape,
304
+ ElementA,
305
+ LayoutATag,
306
+ SparseConfig
307
+ >;
308
+ };
309
+
310
+ template<
311
+ class ProblemShape,
312
+ class ElementA,
313
+ class LayoutATag,
314
+ class SparseConfig,
315
+ class ArchTag
316
+ >
317
+ using StructuredSparseCompressor = typename StructuredSparseCompressorSelector<
318
+ ProblemShape,
319
+ ElementA,
320
+ LayoutATag,
321
+ SparseConfig,
322
+ ArchTag
323
+ >::Compressor;
324
+
325
+ } // End namespace cutlass::transform::kernel
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing how threads are mapped to a given tile.
33
+
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/coord.h"
41
+ #include "cutlass/predicate_vector.h"
42
+ #include "cutlass/tensor_ref.h"
43
+ #include "cutlass/tensor_view.h"
44
+ #include "cutlass/layout/pitch_linear.h"
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace transform {
50
+
51
+ ////////////////////////////////////////////////////////////////////////////////
52
+
53
+ /// Strip-mines a pitch-linear tile among a given number of threads, first along
54
+ /// the contiguous dimension then along the strided dimension.
55
+ ///
56
+ /// The tile must be divisible by the thread count such that all threads may
57
+ /// execute the same number of iterations with the same delta to exhaustively
58
+ /// cover the tile.
59
+ ///
60
+ /// This class satisfies the "RegularThreadMapping" concept.
61
+ ///
62
+ /// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor
63
+ /// kernels.
64
+ template <
65
+ typename Shape_,
66
+ int Threads,
67
+ int ElementsPerAccess = 1
68
+ >
69
+ struct PitchLinearStripminedThreadMap {
70
+
71
+ /// Tensor coordinate
72
+ using TensorCoord = layout::PitchLinearCoord;
73
+
74
+ /// Tile shape
75
+ using Shape = Shape_;
76
+
77
+ /// Number of threads total
78
+ static int const kThreads = Threads;
79
+
80
+ /// Extract vector length from Layout
81
+ static int const kElementsPerAccess = ElementsPerAccess;
82
+
83
+ /// Shape of access by each thread
84
+ using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
85
+
86
+ /// Internal implementation details
87
+ struct Detail {
88
+
89
+ static_assert(!(Shape::kContiguous % kElementsPerAccess), "");
90
+
91
+ /// Shape of the tile in units of vectors
92
+ using ShapeVec = layout::PitchLinearShape<
93
+ Shape::kContiguous / kElementsPerAccess,
94
+ Shape::kStrided
95
+ >;
96
+
97
+ static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) ||
98
+ (!(kThreads % ShapeVec::kContiguous)),
99
+ "Shape must be divisible by number of iterations of each thread.");
100
+ };
101
+
102
+ /// Number of iterations by each thread
103
+ using Iterations = typename platform::conditional<
104
+ Threads >= Detail::ShapeVec::kContiguous,
105
+ layout::PitchLinearShape<
106
+ 1,
107
+ // Redo the comparison here to work around divide by zero compiler
108
+ // error. The compiler evaluates both path of platform::conditional.
109
+ (Threads >= Detail::ShapeVec::kContiguous
110
+ ? (Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) /
111
+ (kThreads / Detail::ShapeVec::kContiguous)
112
+ : 0)>,
113
+ layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
114
+ Detail::ShapeVec::kStrided>>::type;
115
+
116
+
117
+ /// Interval between accesses along each dimension of the tensor's logical coordinate space
118
+ /// (in units of Elements)
119
+ using Delta = typename platform::conditional<
120
+ Threads >= Detail::ShapeVec::kContiguous,
121
+ layout::PitchLinearShape<
122
+ 1,
123
+ kThreads / Detail::ShapeVec::kContiguous
124
+ >,
125
+ layout::PitchLinearShape<
126
+ kThreads * kElementsPerAccess,
127
+ 1
128
+ >
129
+ >::type;
130
+
131
+ /// Shape of the tile in units of vectors
132
+ using StorageShape = typename platform::conditional<
133
+ Threads >= Detail::ShapeVec::kContiguous,
134
+ layout::PitchLinearShape<Shape::kContiguous,
135
+ Iterations::kStrided*(kThreads / Detail::ShapeVec::kContiguous)>,
136
+ layout::PitchLinearShape<Shape::kContiguous, Shape::kStrided>>::type;
137
+
138
+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
139
+ /// (in units of Elements)
140
+ CUTLASS_HOST_DEVICE
141
+ static TensorCoord initial_offset(int thread_id) {
142
+ return TensorCoord(
143
+ (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess,
144
+ thread_id / Detail::ShapeVec::kContiguous);
145
+ }
146
+ };
147
+
148
+ /// This ThreadMap is used by GEMV
149
+ template <
150
+ typename Shape,
151
+ int Threads,
152
+ int ElementsPerAccess = 1
153
+ >
154
+ struct PitchLinearTilePolicyStripminedThreadContiguous
155
+ {
156
+ static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0,
157
+ "Contiguous shape must divide number of threads");
158
+
159
+ using TensorCoord = layout::PitchLinearCoord;
160
+
161
+ static int const kThreads = Threads;
162
+ static int const kElementsPerAccess = ElementsPerAccess;
163
+
164
+ using Iterations = layout::PitchLinearShape<
165
+ Shape::kContiguous / (kThreads * kElementsPerAccess),
166
+ Shape::kStrided>;
167
+
168
+ using Delta = layout::PitchLinearShape<1, 1>;
169
+
170
+ CUTLASS_HOST_DEVICE
171
+ static TensorCoord initial_offset(int thread_id)
172
+ {
173
+ return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0);
174
+ }
175
+ };
176
+
177
+ template <
178
+ typename Shape,
179
+ int Threads,
180
+ int ElementsPerAccess = 1
181
+ >
182
+ struct PitchLinearTilePolicyStripminedThreadStrided
183
+ {
184
+ static_assert((Shape::kStrided % Threads == 0),
185
+ "Strided shape must divide number of threads");
186
+
187
+ using TensorCoord = layout::PitchLinearCoord;
188
+
189
+ static int const kThreads = Threads;
190
+ static int const kElementsPerAccess = ElementsPerAccess;
191
+
192
+ using Iterations = layout::PitchLinearShape<
193
+ Shape::kContiguous / kElementsPerAccess,
194
+ Shape::kStrided / kThreads>;
195
+
196
+ using Delta = layout::PitchLinearShape<1, 1>;
197
+
198
+ using ShapeVec = Shape;
199
+
200
+ CUTLASS_HOST_DEVICE
201
+ static TensorCoord initial_offset(int thread_id)
202
+ {
203
+
204
+ return TensorCoord(0, thread_id * Iterations::kStrided);
205
+ }
206
+ };
207
+
208
+
209
+ ////////////////////////////////////////////////////////////////////////////////
210
+
211
+ /// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous
212
+ /// elements.
213
+ ///
214
+ /// This ThreadMap is used by tensor core kernels.
215
+ template <
216
+ typename Shape_,
217
+ int Threads,
218
+ typename WarpThreadArrangement_,
219
+ int ElementsPerAccess = 1
220
+ >
221
+ struct PitchLinearWarpRakedThreadMap {
222
+
223
+ /// Tensor coordinate
224
+ using TensorCoord = layout::PitchLinearCoord;
225
+
226
+ /// Tile shape
227
+ using Shape = Shape_;
228
+
229
+ /// Number of threads total
230
+ static int const kThreads = Threads;
231
+
232
+ /// Extract vector length from Layout
233
+ static int const kElementsPerAccess = ElementsPerAccess;
234
+
235
+ /// Shape of access by each thread
236
+ using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
237
+
238
+ /// Internal details made public to facilitate introspection
239
+ struct Detail {
240
+
241
+ /// Fixed arrangement of threads within a warp (units of threads).
242
+ using WarpThreadArrangement = WarpThreadArrangement_;
243
+
244
+ /// Number of threads per warp
245
+ static int const kWarpSize = WarpThreadArrangement::kCount;
246
+
247
+ /// Number of participating warps
248
+ static int const kWarpCount = kThreads / kWarpSize;
249
+
250
+ static_assert(
251
+ !(Shape::kContiguous % kElementsPerAccess),
252
+ "Shape must be divisible by vector length.");
253
+
254
+ /// Compute the 'shape' of the overall tile in units of vectors
255
+ using ShapeInAccesses = layout::PitchLinearShape<
256
+ Shape::kContiguous / kElementsPerAccess,
257
+ Shape::kStrided
258
+ >;
259
+
260
+ static_assert(
261
+ !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous),
262
+ "ShapeInAccesses must be divisible by WarpThreadArrangement.");
263
+
264
+ static_assert(
265
+ !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided),
266
+ "ShapeInAccesses must be divisible by WarpThreadArrangement.");
267
+
268
+ // compute number of warp-level accesses total
269
+ using WarpAccessIterations = layout::PitchLinearShape<
270
+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
271
+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
272
+ >;
273
+
274
+ // Divide it into the number of warps, first partitioning the strided dimension then the
275
+ // contiguous.
276
+ static int const kWarpsStrided =
277
+ (WarpAccessIterations::kStrided >= kWarpCount
278
+ ? kWarpCount
279
+ : WarpAccessIterations::kStrided);
280
+
281
+ static int const kWarpsContiguous =
282
+ (kWarpCount > WarpAccessIterations::kStrided
283
+ ? kWarpCount / kWarpsStrided
284
+ : 1);
285
+
286
+ /// Arrangement of warps within a threadblock-scoped tile
287
+ using WarpArrangement = layout::PitchLinearShape<
288
+ kWarpsContiguous, kWarpsStrided
289
+ >;
290
+ };
291
+
292
+ ///< Iterations along each dimension (concept: PitchLinearShape)
293
+ using Iterations = layout::PitchLinearShape<
294
+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
295
+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
296
+ >;
297
+
298
+ static_assert(Iterations::kCount,
299
+ "Number of iterations must be non-zero");
300
+
301
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
302
+ using Delta = layout::PitchLinearShape<
303
+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
304
+ Detail::WarpThreadArrangement::kStrided
305
+ >;
306
+
307
+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
308
+ CUTLASS_HOST_DEVICE
309
+ static TensorCoord initial_offset(int thread_id) {
310
+
311
+ int warp_id = (thread_id / Detail::kWarpSize);
312
+ int lane_id = (thread_id % Detail::kWarpSize);
313
+
314
+ //
315
+ // compute warp-level offset
316
+ //
317
+
318
+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors)
319
+ layout::PitchLinearCoord warp_footprint{
320
+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
321
+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
322
+ };
323
+
324
+ // This is the offset of a specific warp (in units of vectors)
325
+ layout::PitchLinearCoord warp_offset{
326
+ (warp_id % Detail::kWarpsContiguous),
327
+ (warp_id / Detail::kWarpsContiguous)
328
+ };
329
+
330
+ // This is the offset of a specific thread within a warp (units of vectors)
331
+ layout::PitchLinearCoord thread_offset_in_warp{
332
+ lane_id % Detail::WarpThreadArrangement::kContiguous,
333
+ lane_id / Detail::WarpThreadArrangement::kContiguous
334
+ };
335
+
336
+ // This is the offset of a thread within a threadblock tile (units of vectors)
337
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
338
+ warp_footprint * warp_offset + thread_offset_in_warp;
339
+
340
+ // This is the offset of a thread within a threadblock tile (units of elements)
341
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
342
+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
343
+ thread_offset_in_threadblock_tile_vec.strided()
344
+ };
345
+
346
+ return thread_offset_in_threadblock_tile_base;
347
+ }
348
+ };
349
+
350
+ ////////////////////////////////////////////////////////////////////////////////
351
+
352
+ /// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous
353
+ /// elements. Warps are arranged based on a stride.
354
+ ///
355
+ /// This ThreadMap is used by tensor core kernels for NCxHWx layout.
356
+ template <
357
+ typename Shape_,
358
+ int Threads,
359
+ typename WarpThreadArrangement_,
360
+ int ElementsPerAccess = 1
361
+ >
362
+ struct PitchLinearStridedWarpRakedThreadMap {
363
+
364
+ /// Tensor coordinate
365
+ using TensorCoord = layout::PitchLinearCoord;
366
+
367
+ /// Tile shape
368
+ using Shape = Shape_;
369
+
370
+ /// Number of threads total
371
+ static int const kThreads = Threads;
372
+
373
+ using WarpThreadArrangement = WarpThreadArrangement_;
374
+
375
+ /// Extract vector length from Layout
376
+ static int const kElementsPerAccess = ElementsPerAccess;
377
+
378
+ /// Base ThreadMap
379
+ using BaseThreadMap = PitchLinearWarpRakedThreadMap<
380
+ Shape,
381
+ kThreads,
382
+ WarpThreadArrangement,
383
+ kElementsPerAccess
384
+ >;
385
+
386
+ /// Shape of access by each thread
387
+ using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape;
388
+
389
+
390
+ struct Detail {
391
+
392
+ using WarpThreadArrangement = WarpThreadArrangement_;
393
+
394
+ using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations;
395
+
396
+ static int const kWarpSize = BaseThreadMap::Detail::kWarpSize;
397
+
398
+ static int const kWarpCount = BaseThreadMap::Detail::kWarpCount;
399
+
400
+ using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses;
401
+
402
+ // Divide it into the number of warps, first partitioning the contiguous dimension then the
403
+ // stride.
404
+ static int const kWarpsContiguous =
405
+ (WarpAccessIterations::kContiguous >= kWarpCount
406
+ ? kWarpCount
407
+ : WarpAccessIterations::kContiguous);
408
+
409
+ static int const kWarpsStrided =
410
+ (kWarpCount > WarpAccessIterations::kContiguous
411
+ ? kWarpCount / kWarpsContiguous
412
+ : 1);
413
+
414
+ /// Arrangement of warps within a threadblock-scoped tile
415
+ using WarpArrangement = layout::PitchLinearShape<
416
+ kWarpsContiguous, kWarpsStrided
417
+ >;
418
+
419
+ };
420
+
421
+ ///< Iterations along each dimension (concept: PitchLinearShape)
422
+ using Iterations = layout::PitchLinearShape<
423
+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
424
+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
425
+ >;
426
+
427
+ static_assert(Iterations::kCount,
428
+ "Number of iterations must be non-zero");
429
+
430
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
431
+ using Delta = typename BaseThreadMap::Delta;
432
+
433
+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
434
+ CUTLASS_HOST_DEVICE
435
+ static TensorCoord initial_offset(int thread_id) {
436
+
437
+ int warp_id = (thread_id / Detail::kWarpSize);
438
+ int lane_id = (thread_id % Detail::kWarpSize);
439
+
440
+ //
441
+ // compute warp-level offset
442
+ //
443
+
444
+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors)
445
+ layout::PitchLinearCoord warp_footprint{
446
+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
447
+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
448
+ };
449
+
450
+ // This is the offset of a specific warp (in units of vectors)
451
+ layout::PitchLinearCoord warp_offset{
452
+ (warp_id % Detail::kWarpsContiguous),
453
+ (warp_id / Detail::kWarpsContiguous)
454
+ };
455
+
456
+ // This is the offset of a specific thread within a warp (units of vectors)
457
+ layout::PitchLinearCoord thread_offset_in_warp{
458
+ lane_id % Detail::WarpThreadArrangement::kContiguous,
459
+ lane_id / Detail::WarpThreadArrangement::kContiguous
460
+ };
461
+
462
+ // This is the offset of a thread within a threadblock tile (units of vectors)
463
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
464
+ warp_footprint * warp_offset + thread_offset_in_warp;
465
+
466
+ // This is the offset of a thread within a threadblock tile (units of elements)
467
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
468
+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
469
+ thread_offset_in_threadblock_tile_vec.strided()
470
+ };
471
+
472
+ return thread_offset_in_threadblock_tile_base;
473
+ }
474
+
475
+
476
+ };
477
+
478
+ ////////////////////////////////////////////////////////////////////////////////
479
+
480
+ /// Transpose the existing ThreadMap. For example, interleaved layout is like
481
+ /// congruous in the global memory and crosswise in the shared memory. We need
482
+ /// to transpose the coordinates between two.
483
+
484
+ template <typename ThreadMap_, typename WarpThreadArrangement_>
485
+ struct TransposePitchLinearThreadMap {
486
+ /// Underlying ThreadMap
487
+ using ThreadMap = ThreadMap_;
488
+
489
+ /// Tensor coordinate
490
+ using TensorCoord = typename ThreadMap::TensorCoord;
491
+
492
+ /// Tile shape
493
+ using Shape = typename ThreadMap::Shape;
494
+
495
+ /// Number of threads total
496
+ static int const kThreads = ThreadMap::kThreads;
497
+
498
+ /// Extract vector length from Layout
499
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
500
+
501
+ /// Shape of access by each thread
502
+ using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
503
+
504
+ /// Internal details made public to facilitate introspection
505
+ struct Detail {
506
+ /// Fixed arrangement of threads within a warp (units of threads).
507
+ using WarpThreadArrangement = WarpThreadArrangement_;
508
+
509
+ /// Number of threads per warp
510
+ static int const kWarpSize = WarpThreadArrangement::kCount;
511
+
512
+ /// Number of participating warps
513
+ static int const kWarpCount = kThreads / kWarpSize;
514
+
515
+ static_assert(!(Shape::kContiguous % kElementsPerAccess),
516
+ "Shape must be divisible by vector length.");
517
+
518
+ /// Arrangement of warps within a threadblock-scoped tile
519
+ using WarpArrangement =
520
+ layout::PitchLinearShape<ThreadMap::Detail::kWarpsStrided,
521
+ ThreadMap::Detail::kWarpsContiguous>;
522
+ };
523
+
524
+ ///< Iterations along each dimension (concept: PitchLinearShape)
525
+ using Iterations =
526
+ layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
527
+ ThreadMap::Iterations::kContiguous>;
528
+
529
+ static_assert(Iterations::kContiguous == 1,
530
+ "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose");
531
+
532
+ static_assert(Iterations::kCount, "Number of iterations must be non-zero");
533
+
534
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
535
+ using Delta =
536
+ layout::PitchLinearShape<Detail::WarpThreadArrangement::kContiguous *
537
+ kElementsPerAccess,
538
+ Detail::WarpThreadArrangement::kStrided>;
539
+
540
+ /// Maps thread ID to a coordinate offset within the tensor's logical
541
+ /// coordinate space Note this is slightly different from the one of
542
+ /// PitchLinearWarpRakedThreadMap.
543
+ CUTLASS_HOST_DEVICE
544
+ static TensorCoord initial_offset(int thread_id) {
545
+
546
+ int warp_id = (thread_id / Detail::kWarpSize);
547
+ int lane_id = (thread_id % Detail::kWarpSize);
548
+
549
+ //
550
+ // compute warp-level offset
551
+ //
552
+
553
+ // This is the shape of the entire area covered by a warp's memory access
554
+ // (in units of vectors)
555
+ layout::PitchLinearCoord warp_footprint{
556
+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
557
+ Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};
558
+
559
+ // This is the offset of a specific warp (in units of vectors)
560
+ // Note the order of / and %. Also the 2nd operand is kStrided.
561
+ layout::PitchLinearCoord warp_offset{
562
+ (warp_id / Detail::WarpArrangement::kStrided),
563
+ (warp_id % Detail::WarpArrangement::kStrided)};
564
+
565
+ // This is the offset of a specific thread within a warp (units of vectors)
566
+ layout::PitchLinearCoord thread_offset_in_warp{
567
+ lane_id % Detail::WarpThreadArrangement::kContiguous,
568
+ lane_id / Detail::WarpThreadArrangement::kContiguous};
569
+
570
+ // This is the offset of a thread within a threadblock tile (units of
571
+ // vectors)
572
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
573
+ warp_footprint * warp_offset + thread_offset_in_warp;
574
+
575
+ // This is the offset of a thread within a threadblock tile (units of
576
+ // elements)
577
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
578
+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
579
+ thread_offset_in_threadblock_tile_vec.strided()};
580
+
581
+ return thread_offset_in_threadblock_tile_base;
582
+ }
583
+ };
584
+
585
+ template <typename ThreadMap_>
586
+ struct TransposePitchLinearThreadMapSimt {
587
+ /// Underlying ThreadMap
588
+ using ThreadMap = ThreadMap_;
589
+
590
+ /// Tensor coordinate
591
+ using TensorCoord = typename ThreadMap::TensorCoord;
592
+
593
+ /// Tile shape
594
+ using Shape = typename ThreadMap::Shape;
595
+
596
+ /// Number of threads total
597
+ static int const kThreads = ThreadMap::kThreads;
598
+
599
+ /// Extract vector length from Layout
600
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
601
+
602
+ static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1");
603
+ ///< Iterations along each dimension (concept: PitchLinearShape)
604
+ using Iterations =
605
+ layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
606
+ ThreadMap::Iterations::kContiguous>;
607
+
608
+ static_assert(Iterations::kCount, "Number of iterations must be non-zero");
609
+
610
+ static_assert(Iterations::kStrided == 1,
611
+ "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose");
612
+
613
+ /// Shape of access by each thread
614
+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
615
+
616
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
617
+ using Delta =
618
+ layout::PitchLinearShape<ThreadMap::Delta::kStrided,
619
+ ThreadMap::Delta::kContiguous>;
620
+
621
+
622
+ /// Maps thread ID to a coordinate offset within the tensor's logical
623
+ /// coordinate space Note this is slightly different from the one of
624
+ /// PitchLinearWarpRakedThreadMap.
625
+ CUTLASS_HOST_DEVICE
626
+ static TensorCoord initial_offset(int thread_id) {
627
+
628
+ TensorCoord coord = ThreadMap::initial_offset(thread_id);
629
+
630
+ return TensorCoord(
631
+ coord.strided(),
632
+ coord.contiguous()
633
+ );
634
+ }
635
+ };
636
+
637
+ ////////////////////////////////////////////////////////////////////////////////
638
+
639
+
640
+ /// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory
641
+ /// accesses performed by each warp then distributes warps across them. Warps are striped in the
642
+ /// strided dimension and raked across the contiguous dimension.
643
+ template <
644
+ typename Shape_, /// Overall shape to partition in units of elements
645
+ int Threads, /// Number of partiticipation threads
646
+ typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp
647
+ int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size)
648
+ >
649
+ struct PitchLinearWarpStripedThreadMap {
650
+
651
+ /// Tensor coordinate
652
+ using TensorCoord = layout::PitchLinearCoord;
653
+
654
+ /// Tile shape
655
+ using Shape = Shape_;
656
+
657
+ /// Number of threads total
658
+ static int const kThreads = Threads;
659
+
660
+ /// Extract vector length from Layout
661
+ static int const kElementsPerAccess = ElementsPerAccess;
662
+
663
+ /// Shape of access by each thread
664
+ using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
665
+
666
+ /// Internal details made public to facilitate introspection
667
+ struct Detail {
668
+
669
+ /// Fixed arrangement of threads within a warp (units of threads).
670
+ using WarpThreadArrangement = WarpThreadArrangement_;
671
+
672
+ /// Number of threads per warp
673
+ static int const kWarpSize = WarpThreadArrangement::kCount;
674
+
675
+ /// Number of participating warps
676
+ static int const kWarpCount = kThreads / kWarpSize;
677
+
678
+ static_assert(
679
+ !(Shape::kContiguous % kElementsPerAccess),
680
+ "Shape must be divisible by vector length.");
681
+
682
+ /// Compute the 'shape' of the overall tile in units of vectors
683
+ using ShapeInAccesses = layout::PitchLinearShape<
684
+ Shape::kContiguous / kElementsPerAccess,
685
+ Shape::kStrided
686
+ >;
687
+
688
+ // compute number of warp-level accesses total
689
+ using WarpAccessIterations = layout::PitchLinearShape<
690
+ ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
691
+ ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
692
+ >;
693
+
694
+ // Divide it into the number of warps, first partitioning the strided dimension then the
695
+ // contiguous.
696
+ static int const kWarpsStrided =
697
+ (WarpAccessIterations::kStrided >= kWarpCount
698
+ ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided));
699
+
700
+ static int const kWarpsContiguous =
701
+ (kWarpCount > WarpAccessIterations::kStrided ?
702
+ WarpAccessIterations::kContiguous / kWarpsStrided : 1);
703
+
704
+ /// Arrangement of warps within a threadblock-scoped tile
705
+ using WarpArrangement = layout::PitchLinearShape<
706
+ kWarpsContiguous, kWarpsStrided
707
+ >;
708
+ };
709
+
710
+ ///< Iterations along each dimension (concept: PitchLinearShape)
711
+ using Iterations = layout::PitchLinearShape<
712
+ Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
713
+ Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
714
+ >;
715
+
716
+ static_assert(Iterations::kCount,
717
+ "Number of iterations must be non-zero");
718
+
719
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
720
+ using Delta = layout::PitchLinearShape<
721
+ Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
722
+ Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided
723
+ >;
724
+
725
+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
726
+ CUTLASS_HOST_DEVICE
727
+ static TensorCoord initial_offset(int thread_id) {
728
+
729
+ int warp_id = (thread_id / Detail::kWarpSize);
730
+ int lane_id = (thread_id % Detail::kWarpSize);
731
+
732
+ //
733
+ // compute warp-level offset
734
+ //
735
+
736
+ // This is the shape of the entire area covered by a warp's memory access (in units of vectors)
737
+ layout::PitchLinearCoord warp_footprint{
738
+ Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
739
+ Detail::WarpThreadArrangement::kStrided
740
+ };
741
+
742
+ // This is the offset of a specific warp (in units of vectors)
743
+ layout::PitchLinearCoord warp_offset{
744
+ (warp_id % Detail::kWarpsContiguous),
745
+ (warp_id / Detail::kWarpsContiguous)
746
+ };
747
+
748
+ // This is the offset of a specific thread within a warp (units of vectors)
749
+ layout::PitchLinearCoord thread_offset_in_warp{
750
+ lane_id % Detail::WarpThreadArrangement::kContiguous,
751
+ lane_id / Detail::WarpThreadArrangement::kContiguous
752
+ };
753
+
754
+ // This is the offset of a thread within a threadblock tile (units of vectors)
755
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
756
+ warp_footprint * warp_offset + thread_offset_in_warp;
757
+
758
+ // This is the offset of a thread within a threadblock tile (units of elements)
759
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
760
+ thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
761
+ thread_offset_in_threadblock_tile_vec.strided()
762
+ };
763
+
764
+ return thread_offset_in_threadblock_tile_base;
765
+ }
766
+ };
767
+
768
+ /////////////////////////////////////////////////////////////////////////////////////////////////
769
+ /// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous
770
+ /// dimension then along the strided dimension, while each thread access a 2D thread-tile.
771
+ ///
772
+ /// The tile must be divisible by the thread count such that all threads may execute the same
773
+ /// number of iterations with the same delta to exhaustively cover the tile.
774
+ ///
775
+ /// This class satisfies the "RegularThreadMapping" concept.
776
+ template <
777
+ typename Shape_,
778
+ int Threads,
779
+ typename ThreadTileShape
780
+ >
781
+ struct PitchLinear2DThreadTileStripminedThreadMap;
782
+
783
+
784
+ template <
785
+ typename Shape_,
786
+ int Threads
787
+ >
788
+ struct PitchLinear2DThreadTileStripminedThreadMap <Shape_, Threads, cutlass::layout::PitchLinearShape<4, 4>>{
789
+
790
+ /// Tensor coordinate
791
+ using TensorCoord = layout::PitchLinearCoord;
792
+
793
+ /// Tile shape
794
+ using Shape = Shape_;
795
+
796
+ /// Access Shape of each thread
797
+ using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>;
798
+ //using ThreadAccessShape = ThreadTileShape;
799
+
800
+ /// Number of threads total
801
+ static int const kThreads = Threads;
802
+
803
+ /// Extract length of each access from Layout
804
+ static int const kElementsPerAccess = ThreadAccessShape::kContiguous;
805
+
806
+ static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)");
807
+
808
+ /// Internal implementation details
809
+ struct Detail {
810
+
811
+ static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4");
812
+
813
+ static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), "");
814
+
815
+ static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)),
816
+ "Shape must be divisible thread count * accesses per thread.");
817
+
818
+ /// Shape of the tile in units of vectors
819
+ using ShapeVec = layout::PitchLinearShape<
820
+ Shape::kContiguous / ThreadAccessShape::kContiguous,
821
+ Shape::kStrided / ThreadAccessShape::kStrided
822
+ >;
823
+
824
+ static_assert(
825
+ (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) ||
826
+ (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))),
827
+ "Shape must be divisible by number of iterations of each thread."
828
+ );
829
+ };
830
+
831
+ /// Number of iterations by each thread
832
+ using Iterations = typename platform::conditional<
833
+ Threads >= Detail::ShapeVec::kContiguous,
834
+ layout::PitchLinearShape<
835
+ 1,
836
+ // Redo the comparison here to work around divide by zero compiler
837
+ // error. The compiler evaluates both path of platform::conditional.
838
+ (Threads >= Detail::ShapeVec::kContiguous
839
+ ? Detail::ShapeVec::kStrided /
840
+ (kThreads / Detail::ShapeVec::kContiguous)
841
+ : 0)>,
842
+ layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
843
+ Detail::ShapeVec::kStrided>>::type;
844
+
845
+ /// Interval between accesses along each dimension of the tensor's logical coordinate space
846
+ /// (in units of Elements)
847
+ using Delta = typename platform::conditional<
848
+ Threads >= Detail::ShapeVec::kContiguous,
849
+ layout::PitchLinearShape<
850
+ Shape::kContiguous,
851
+ kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous
852
+ >,
853
+ layout::PitchLinearShape<
854
+ kThreads * ThreadAccessShape::kContiguous,
855
+ 1
856
+ >
857
+ >::type;
858
+
859
+ /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
860
+ /// (in units of Elements)
861
+ CUTLASS_HOST_DEVICE
862
+ static TensorCoord initial_offset(int thread_id) {
863
+
864
+ return TensorCoord(
865
+ (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous,
866
+ (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided);
867
+ }
868
+ };
869
+
870
+ /// Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping
871
+ template <typename ThreadMap_>
872
+ struct TransposePitchLinearThreadMap2DThreadTile {
873
+ /// Underlying ThreadMap
874
+ using ThreadMap = ThreadMap_;
875
+
876
+ /// Tensor coordinate
877
+ using TensorCoord = typename ThreadMap::TensorCoord;
878
+
879
+ /// Tile shape
880
+ using Shape = typename ThreadMap::Shape;
881
+
882
+ /// Number of threads total
883
+ static int const kThreads = ThreadMap::kThreads;
884
+
885
+ /// Extract vector length from Layout
886
+ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
887
+
888
+
889
+ static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1");
890
+ ///< Iterations along each dimension (concept: PitchLinearShape)
891
+ using Iterations =
892
+ layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
893
+ ThreadMap::Iterations::kContiguous>;
894
+
895
+ static_assert(Iterations::kCount, "Number of iterations must be non-zero");
896
+
897
+ /// Shape of access by each thread
898
+ using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
899
+
900
+ ///< Delta between accesses (units of elements, concept: PitchLinearShape)
901
+ using Delta =
902
+ layout::PitchLinearShape<ThreadMap::Delta::kStrided,
903
+ ThreadMap::Delta::kContiguous>;
904
+
905
+
906
+ /// Maps thread ID to a coordinate offset within the tensor's logical
907
+ /// coordinate space Note this is slightly different from the one of
908
+ /// PitchLinearWarpRakedThreadMap.
909
+ CUTLASS_HOST_DEVICE
910
+ static TensorCoord initial_offset(int thread_id) {
911
+
912
+ TensorCoord coord = ThreadMap::initial_offset(thread_id);
913
+ return TensorCoord(
914
+ coord.strided(),
915
+ coord.contiguous()
916
+ );
917
+ }
918
+ };
919
+
920
+
921
+ /////////////////////////////////////////////////////////////////////////////////////////////////
922
+
923
+ } // namespace transform
924
+ } // namespace cutlass
925
+
926
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Basic copy routines for tensor views
34
+ */
35
+
36
+ #pragma once
37
+
38
+ namespace cutlass {
39
+ namespace transform {
40
+ namespace thread {
41
+
42
+ /// Transforms a fragment by doing a transpose
43
+ template <
44
+ int ElementCount,
45
+ typename TransposeShape,
46
+ typename Element
47
+ > struct Transpose;
48
+
49
+ /// Specialization for int8_t 4x4 transpose
50
+ template <int ElementCount_>
51
+ struct Transpose<ElementCount_, layout::PitchLinearShape<4,4> , int8_t> {
52
+
53
+ static const int kElementCount = ElementCount_;
54
+ using TransposeShape = layout::PitchLinearShape<4,4>;
55
+ using Element = int8_t;
56
+ using Fragment = cutlass::Array<Element, kElementCount>;
57
+
58
+ static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose");
59
+
60
+ CUTLASS_DEVICE
61
+ void transform(Fragment& dst, Fragment& src) {
62
+
63
+ // Expose src/dst as int arrays.
64
+ int* src_int = reinterpret_cast<int*>(&src);
65
+ int* dst_int = reinterpret_cast<int*>(&dst);
66
+
67
+ CUTLASS_PRAGMA_UNROLL
68
+ for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){
69
+
70
+ int const i0 = 4 * i + 0;
71
+ int const i1 = 4 * i + 1;
72
+ int const i2 = 4 * i + 2;
73
+ int const i3 = 4 * i + 3;
74
+
75
+ int a0 = src_int[i0];
76
+ int a1 = src_int[i1];
77
+ int a2 = src_int[i2];
78
+ int a3 = src_int[i3];
79
+
80
+ int b0, b1, b2, b3, c0;
81
+ b0 = __byte_perm(a0, a1, 0x0040);
82
+ c0 = __byte_perm(a2, a3, 0x0040);
83
+ b0 = __byte_perm(b0, c0, 0x5410);
84
+
85
+ b1 = __byte_perm(a0, a1, 0x0051);
86
+ c0 = __byte_perm(a2, a3, 0x0051);
87
+ b1 = __byte_perm(b1, c0, 0x5410);
88
+
89
+ b2 = __byte_perm(a0, a1, 0x0062);
90
+ c0 = __byte_perm(a2, a3, 0x0062);
91
+ b2 = __byte_perm(b2, c0, 0x5410);
92
+
93
+ b3 = __byte_perm(a0, a1, 0x0073);
94
+ c0 = __byte_perm(a2, a3, 0x0073);
95
+ b3 = __byte_perm(b3, c0, 0x5410);
96
+
97
+ dst_int[i0] = b0;
98
+ dst_int[i1] = b1;
99
+ dst_int[i2] = b2;
100
+ dst_int[i3] = b3;
101
+ }
102
+ }
103
+ };
104
+
105
+ } // namespace thread
106
+ } // namespace layout
107
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ #pragma once
32
+
33
+ #include "cutlass/cutlass.h"
34
+ #include "cutlass/complex.h"
35
+
36
+ namespace cutlass {
37
+ namespace transform {
38
+ namespace thread {
39
+
40
+ namespace UnaryTransform {
41
+ struct Identity; ///< None (i.e., identity)
42
+ struct Conjugate; ///< Complex conjugate
43
+ }
44
+
45
+ /// Element-wise unary operator that transforms one element of a fragment at a time
46
+ template<
47
+ typename FragmentIn, ///< Input Fragment
48
+ typename FragmentOut,///< Output Fragment
49
+ typename Transform> ///< Unary transform operator
50
+ class UnaryOp
51
+ {
52
+ public:
53
+ CUTLASS_DEVICE
54
+ static FragmentOut execute(FragmentIn &in)
55
+ {
56
+ static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
57
+ static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
58
+ platform::is_same<Transform, UnaryTransform::Conjugate>::value,
59
+ "Unary Operator not supported.");
60
+
61
+ FragmentOut out;
62
+ if (platform::is_same<Transform, UnaryTransform::Identity>::value )
63
+ {
64
+ CUTLASS_PRAGMA_UNROLL
65
+ for (int i=0; i < FragmentIn::kElements; ++i){
66
+ out[i] = static_cast<typename FragmentOut::Element>(in[i]);
67
+ }
68
+ }
69
+ else if (platform::is_same<Transform, UnaryTransform::Conjugate>::value )
70
+ {
71
+ for (int i=0; i < FragmentIn::kElements; ++i){
72
+ out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
73
+ }
74
+ }
75
+ return out;
76
+ }
77
+ };
78
+
79
+ template<typename FragmentIn, typename Transform>
80
+ class UnaryOp<FragmentIn, FragmentIn, Transform>
81
+ {
82
+ public:
83
+ CUTLASS_DEVICE
84
+ static FragmentIn execute(FragmentIn &in)
85
+ {
86
+ static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
87
+ platform::is_same<Transform, UnaryTransform::Conjugate>::value,
88
+ "Unary Operator not supported.");
89
+
90
+ if (platform::is_same<Transform, UnaryTransform::Identity>::value )
91
+ {
92
+ return in;
93
+ }
94
+ else if (platform::is_same<Transform, UnaryTransform::Conjugate>::value )
95
+ {
96
+ for(int i=0; i < FragmentIn::kElements; ++i){
97
+ in[i] = conj(in[i]);
98
+ }
99
+ }
100
+ return in;
101
+ }
102
+ };
103
+ }
104
+ }
105
+ }
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Ell iterator for matrix of indices (ellColInd matrix)
33
+ */
34
+
35
+ #pragma once
36
+
37
+ namespace cutlass {
38
+ namespace transform {
39
+ namespace threadblock {
40
+
41
+ namespace ell{
42
+
43
+ constexpr unsigned int SmemPow = 8;
44
+ constexpr unsigned int SmemStages = 2;
45
+ constexpr unsigned int SmemSize = 1 << SmemPow;
46
+ constexpr unsigned int SmemMask = (SmemSize*SmemStages-1);
47
+
48
+ class SharedStorage{
49
+ public:
50
+ Array<int, SmemSize*SmemStages> array;
51
+ };
52
+
53
+ class Iterator{
54
+ public:
55
+ using Layout = layout::PitchLinear;
56
+ using LongIndex = typename Layout::LongIndex;
57
+
58
+ private:
59
+ const int *gmem_col_idx_;
60
+ int *smem_col_idx_;
61
+ const int block_size_;
62
+ const int base_idx_;
63
+ const int k_shape_;
64
+ const int ell_increment_;
65
+ const int array_length_;
66
+ int col_idx_base_;
67
+ int residue_;
68
+ int counter_;
69
+
70
+ int pow2_;
71
+ int residue_shape_;
72
+
73
+ int smem_offset_;
74
+ int smem_stage_;
75
+ int gmem_offset_;
76
+
77
+ int lane_;
78
+
79
+ bool is_pow2_;
80
+ bool is_residue_tile_;
81
+
82
+ public:
83
+ CUTLASS_DEVICE
84
+ void load_ell_indices(){
85
+ for(int i=threadIdx.x; i<SmemSize; i+=blockDim.x){
86
+ int idx = (gmem_offset_+i < array_length_) ? gmem_offset_+i : array_length_-1;
87
+ int gmem_col_idx = gmem_col_idx_[idx] - base_idx_;
88
+ smem_col_idx_[i + smem_stage_ * SmemSize] =
89
+ (gmem_col_idx >= 0) ? gmem_col_idx : -1;
90
+ }
91
+ gmem_offset_ += SmemSize;
92
+ smem_stage_ ^= 1;
93
+ }
94
+
95
+ CUTLASS_DEVICE
96
+ Iterator(
97
+ SharedStorage& shared_storage_base,
98
+ const int* col_idx,
99
+ const int& block_size,
100
+ const int& base_idx,
101
+ const int k_shape,
102
+ const int& problem_size_k,
103
+ const int& ell_stride,
104
+ const int& thread_idx)
105
+ : residue_(0),
106
+ counter_(0),
107
+ smem_offset_(0),
108
+ smem_stage_(0),
109
+ gmem_offset_(0),
110
+ block_size_(block_size),
111
+ base_idx_(base_idx),
112
+ k_shape_(k_shape),
113
+ ell_increment_(ell_stride * block_size),
114
+ array_length_((problem_size_k + block_size_ - 1) / block_size_),
115
+ residue_shape_(problem_size_k % k_shape_),
116
+ is_residue_tile_(residue_shape_ != 0),
117
+ smem_col_idx_(reinterpret_cast<int*>(&shared_storage_base.array)),
118
+ gmem_col_idx_(const_cast<int*>(col_idx)),
119
+ lane_(thread_idx % 32) {
120
+
121
+ load_ell_indices();
122
+ __syncthreads();
123
+
124
+ is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0);
125
+ if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0;
126
+
127
+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_;
128
+
129
+ pow2_ = 0;
130
+ while(block_size_ >> (pow2_ + 1)) ++pow2_;
131
+ }
132
+
133
+ CUTLASS_DEVICE
134
+ int get_blocksize(){
135
+ return block_size_;
136
+ }
137
+
138
+ CUTLASS_DEVICE
139
+ Iterator &operator++(){
140
+ if(is_residue_tile_){
141
+ residue_ += residue_shape_;
142
+ is_residue_tile_ = false;
143
+ } else {
144
+ residue_ += k_shape_;
145
+ }
146
+
147
+ if(residue_ < block_size_){
148
+ return *this;
149
+ }
150
+
151
+ if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_))
152
+ load_ell_indices();
153
+
154
+ if(residue_ == block_size_){
155
+ ++smem_offset_;
156
+ counter_ += ell_increment_;
157
+ residue_ = 0;
158
+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_;
159
+ return *this;
160
+ }
161
+
162
+ if(is_pow2_){
163
+ smem_offset_ += residue_ >> pow2_;
164
+ counter_ += (residue_ >> pow2_) * ell_increment_;
165
+ residue_ = residue_ & ((1 << pow2_) - 1);
166
+ }
167
+ else {
168
+ smem_offset_ += residue_ / block_size_;
169
+ counter_ += (residue_ / block_size_) * ell_increment_;
170
+ residue_ %= block_size_;
171
+ }
172
+
173
+ col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_;
174
+
175
+ return *this;
176
+ }
177
+
178
+ CUTLASS_DEVICE
179
+ LongIndex get_offset(const int& idx) {
180
+ int num_jump_tiles;
181
+ if(is_pow2_)
182
+ num_jump_tiles = (idx + residue_) >> pow2_;
183
+ else
184
+ num_jump_tiles = (idx + residue_) / block_size_;
185
+
186
+ int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles);
187
+ return tmp - num_jump_tiles * ell_increment_;
188
+ }
189
+
190
+ CUTLASS_DEVICE
191
+ LongIndex get_offset_fast() {
192
+ return col_idx_base_;
193
+ }
194
+ };
195
+
196
+ }
197
+ }
198
+ }
199
+ }
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h ADDED
@@ -0,0 +1,1350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/array.h"
38
+ #include "cutlass/coord.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/layout/matrix.h"
41
+ #include "cutlass/layout/pitch_linear.h"
42
+ #include "cutlass/matrix_shape.h"
43
+ #include "cutlass/predicate_vector.h"
44
+ #include "cutlass/tensor_ref.h"
45
+ #include "cutlass/tensor_view.h"
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////
48
+
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace transform {
53
+ namespace threadblock {
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ /// EllPredicatedTileAccessIterator
58
+ ///
59
+ template <typename Shape, typename Element, typename Layout, int AdvanceRank,
60
+ typename ThreadMap, typename AccessType>
61
+ class EllPredicatedTileAccessIterator;
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
+ /// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
66
+ ///
67
+ template <typename Shape_, typename Element_, int AdvanceRank,
68
+ typename ThreadMap_, typename AccessType_>
69
+ class EllPredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
70
+ AdvanceRank, ThreadMap_, AccessType_> {
71
+ public:
72
+ static_assert(
73
+ AdvanceRank == 0 || AdvanceRank == 1,
74
+ "Specialization for pitch-linear iterator may along advance along the "
75
+ "contiguous(rank=0) or strided(rank=1) dimension.");
76
+
77
+ using Shape = Shape_;
78
+ using Element = Element_;
79
+ using Layout = layout::PitchLinear;
80
+ static int const kAdvanceRank = AdvanceRank;
81
+ using ThreadMap = ThreadMap_;
82
+ using AccessType = AccessType_;
83
+
84
+ using Index = typename Layout::Index;
85
+ using LongIndex = typename Layout::LongIndex;
86
+
87
+ using TensorRef = TensorRef<Element, Layout>;
88
+ using TensorView = TensorView<Element, Layout>;
89
+ using TensorCoord = typename Layout::TensorCoord;
90
+
91
+ using Pointer = Element *;
92
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
93
+
94
+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
95
+
96
+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
97
+ "Vectors implied by the thread map must be divisible by the access type.");
98
+
99
+ static int const kPredicatesPerByte = 4;
100
+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
101
+
102
+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
103
+
104
+ /// Number of 32b words containing predicates
105
+ static int const kPredicateByteCount =
106
+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
107
+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
108
+
109
+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
110
+
111
+ static_assert(kPredicateWordCount <= 4, "Too many predicates.");
112
+
113
+ /// Predicate vector stores mask to guard accesses
114
+ using Mask = Array<uint32_t, kPredicateWordCount>;
115
+
116
+ /// Parameters object is precomputed state and is host-constructible
117
+ class Params {
118
+ public:
119
+ friend EllPredicatedTileAccessIterator;
120
+
121
+ private:
122
+ /// stride of pitch-linear layout (units of Element)
123
+ LongIndex stride_;
124
+ /// amount (in byte) to increment pointer to move to next access along
125
+ /// strided dimension
126
+ LongIndex inc_strided_;
127
+ /// amount (in byte) to increment pointer from last access to first access
128
+ /// of next tile
129
+ LongIndex inc_next_;
130
+ /// amount (in byte) to increment pointer from first access of current tile
131
+ /// to first access of next tile
132
+ LongIndex inc_advance_;
133
+
134
+ public:
135
+
136
+ // Default ctor
137
+ CUTLASS_HOST_DEVICE
138
+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
139
+
140
+ /// Construct the Params object given a pitch-linear tensor's layout
141
+ CUTLASS_HOST_DEVICE
142
+ Params(Layout const &layout) : stride_(layout.stride(0)) {
143
+ inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
144
+ sizeof_bits<Element>::value / 8;
145
+
146
+ if (kAdvanceRank) {
147
+ // advance along strided dimension
148
+ inc_advance_ =
149
+ Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
150
+ } else {
151
+ // advance along contiguous dimension
152
+ inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
153
+ }
154
+
155
+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
156
+ ThreadMap::Delta::kStrided * LongIndex(stride_) *
157
+ sizeof_bits<Element>::value / 8;
158
+ };
159
+ };
160
+
161
+ private:
162
+ /// Internal pointer type permits fast address arithmetic
163
+ using BytePointer = char *;
164
+
165
+ private:
166
+ //
167
+ // Data members
168
+ //
169
+
170
+ /// Parameters object with precomputed internal state
171
+ Params const &params_;
172
+
173
+ /// Internal pointer to first access of tile
174
+ BytePointer pointer_;
175
+
176
+ /// Guard predicates
177
+ uint32_t predicates_[kPredicateWordCount];
178
+
179
+ /// Size of tensor
180
+ TensorCoord extent_;
181
+
182
+ /// Initial offset for each thread
183
+ TensorCoord thread_offset_;
184
+
185
+ /// Offset to the first steady-state tile
186
+ TensorCoord residue_offset_;
187
+
188
+ /// Initial offset to define ELL block
189
+ TensorCoord ell_offset_;
190
+
191
+ /// Used for out-of-order visitation
192
+ bool is_residue_tile_;
193
+
194
+ /// Iteration along vectors implied by the thread map
195
+ int iteration_vector_;
196
+
197
+ /// Iteration in the contiguous dimension
198
+ int iteration_contiguous_;
199
+
200
+ /// Iteration in the strided dimension
201
+ int iteration_strided_;
202
+
203
+ public:
204
+ /// Computes predicates based on internally tracked per-thread offset.
205
+ CUTLASS_DEVICE
206
+ void compute_predicates_(
207
+ /// Extent of the matrix window
208
+ TensorCoord extent,
209
+ /// optionally, simplify predicate calculation during 'steady state' phase
210
+ bool is_steady_state = false) {
211
+
212
+ CUTLASS_PRAGMA_UNROLL
213
+ for (int i = 0; i < kPredicateWordCount; ++i) {
214
+ predicates_[i] = 0u;
215
+ }
216
+
217
+ CUTLASS_PRAGMA_UNROLL
218
+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
219
+
220
+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
221
+
222
+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
223
+
224
+ int c = access_residual / kAccessesPerVector;
225
+ int v = access_residual % kAccessesPerVector;
226
+
227
+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
228
+ s * ThreadMap::Delta::kStrided);
229
+
230
+ TensorCoord coord = thread_offset_ + iteration_coord;
231
+
232
+ bool guard;
233
+
234
+ if (is_steady_state) {
235
+ if (kAdvanceRank == 0) {
236
+ guard = (coord.strided() < extent.strided());
237
+ } else {
238
+ guard = (coord.contiguous() < extent.contiguous());
239
+ }
240
+ } else {
241
+ guard = (coord.strided() < extent.strided() &&
242
+ coord.contiguous() < extent.contiguous());
243
+ }
244
+
245
+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
246
+
247
+ int word_idx = pred_idx / kPredicatesPerWord;
248
+ int residual = pred_idx % kPredicatesPerWord;
249
+ int byte_idx = residual / kPredicatesPerByte;
250
+ int bit_idx = residual % kPredicatesPerByte;
251
+
252
+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
253
+
254
+ }
255
+
256
+ }
257
+
258
+ public:
259
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
260
+ /// and thread ID
261
+ CUTLASS_HOST_DEVICE
262
+ EllPredicatedTileAccessIterator(
263
+ /// Precomputed parameters object
264
+ Params const &params,
265
+ /// Pointer to start of tensor
266
+ Pointer pointer,
267
+ /// Extent of tensor
268
+ TensorCoord extent,
269
+ /// ID of each participating thread
270
+ int thread_id,
271
+ /// Initial offset of threadblock
272
+ TensorCoord const &threadblock_offset)
273
+ : params_(params),
274
+ pointer_(reinterpret_cast<BytePointer>(
275
+ const_cast<NonConstPointer>(pointer))),
276
+ extent_(extent),
277
+ is_residue_tile_(true) {
278
+
279
+ TensorCoord residue_extent;
280
+ if (kAdvanceRank) {
281
+
282
+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
283
+ if (!residue_size) {
284
+ residue_size = Shape::kStrided;
285
+ }
286
+
287
+ residue_offset_ = make_Coord(0, residue_size);
288
+ residue_extent = make_Coord(
289
+ extent_.contiguous(),
290
+ min(threadblock_offset.strided() + residue_size, extent_.strided())
291
+ );
292
+ } else {
293
+
294
+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
295
+ if (!residue_size) {
296
+ residue_size = Shape::kContiguous;
297
+ }
298
+
299
+ residue_offset_ = make_Coord(residue_size, 0);
300
+
301
+ residue_extent = make_Coord(
302
+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
303
+ extent_.strided()
304
+ );
305
+ }
306
+
307
+ // Per-thread offset in logical coordinates of tensor
308
+ ell_offset_ = ThreadMap::initial_offset(thread_id);
309
+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
310
+
311
+ // update internal pointers
312
+ Layout layout(params_.stride_);
313
+ add_pointer_offset(layout(thread_offset_));
314
+
315
+ compute_predicates_(residue_extent, false);
316
+
317
+ set_iteration_index(0);
318
+ }
319
+
320
+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
321
+ CUTLASS_HOST_DEVICE
322
+ EllPredicatedTileAccessIterator(
323
+ /// Precomputed parameters object
324
+ Params const &params,
325
+ /// Pointer to start of tensor
326
+ Pointer pointer,
327
+ /// Extent of tensor
328
+ TensorCoord extent,
329
+ ///< ID of each participating thread
330
+ int thread_id)
331
+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
332
+ make_Coord(0, 0)) {}
333
+
334
+ /// Overrides the internal iteration index
335
+ CUTLASS_HOST_DEVICE
336
+ void set_iteration_index(int index) {
337
+
338
+ iteration_vector_ = index % kAccessesPerVector;
339
+ int residual_access = index / kAccessesPerVector;
340
+
341
+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
342
+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
343
+
344
+ }
345
+
346
+ /// Adds a pointer offset in units of Element
347
+ CUTLASS_HOST_DEVICE
348
+ void add_pointer_offset(LongIndex pointer_offset) {
349
+ pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
350
+ }
351
+
352
+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles
353
+ CUTLASS_DEVICE
354
+ void add_tile_offset(
355
+ TensorCoord const &tile_offset) {
356
+ if (is_residue_tile_) {
357
+
358
+ thread_offset_ += residue_offset_;
359
+
360
+ Layout layout(params_.stride_);
361
+ add_pointer_offset(layout(residue_offset_));
362
+
363
+ compute_predicates_(extent_, true);
364
+
365
+ if (kAdvanceRank) {
366
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
367
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
368
+ } else {
369
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
370
+ pointer_ += Shape::kStrided * tile_offset.strided();
371
+ }
372
+ } else {
373
+ if (kAdvanceRank) {
374
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
375
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
376
+ } else {
377
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
378
+ pointer_ += Shape::kStrided * tile_offset.strided();
379
+ }
380
+ }
381
+ is_residue_tile_ = false;
382
+ }
383
+
384
+ /// Returns a pointer
385
+ CUTLASS_HOST_DEVICE
386
+ AccessType *get() const {
387
+ return reinterpret_cast<AccessType *>(
388
+ pointer_ +
389
+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
390
+ }
391
+
392
+ /// Returns a k_location
393
+ CUTLASS_HOST_DEVICE
394
+ int get_k() const {
395
+ if(kAdvanceRank){ //strided
396
+ return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided;
397
+ }else{
398
+ return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements;
399
+ }
400
+ }
401
+
402
+ CUTLASS_HOST_DEVICE
403
+ int get_stride() const {
404
+ if(kAdvanceRank)
405
+ return params_.stride_;
406
+ else
407
+ return 1;
408
+ }
409
+
410
+ /// Increment and return an instance to self.
411
+ CUTLASS_HOST_DEVICE
412
+ EllPredicatedTileAccessIterator &operator++() {
413
+
414
+ ++iteration_vector_;
415
+ if (iteration_vector_ < kAccessesPerVector) {
416
+ return *this;
417
+ }
418
+
419
+ iteration_vector_ = 0;
420
+ ++iteration_contiguous_;
421
+
422
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
423
+ return *this;
424
+ }
425
+
426
+ // Enter here only if (iteration_contiguous_ ==
427
+ // ThreadMap::Iteration::kContiguous)
428
+ iteration_contiguous_ = 0;
429
+ ++iteration_strided_;
430
+
431
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
432
+ pointer_ += params_.inc_strided_;
433
+ return *this;
434
+ }
435
+
436
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
437
+ // which means we enter the next tile.
438
+ iteration_strided_ = 0;
439
+
440
+ // advance to next tile
441
+ pointer_ += params_.inc_next_;
442
+
443
+ // now return to start tile - if the iterator is subsequently advanced, this
444
+ // subtraction as well as the subsequent integer addition are both elided by
445
+ // the compiler.
446
+ pointer_ -= params_.inc_advance_;
447
+
448
+ return *this;
449
+ }
450
+
451
+ /// Increment and return an instance to self.
452
+ CUTLASS_HOST_DEVICE
453
+ EllPredicatedTileAccessIterator operator++(int) {
454
+ EllPredicatedTileAccessIterator self(*this);
455
+ operator++();
456
+ return self;
457
+ }
458
+
459
+ /// Clears the predicate set efficiently
460
+ CUTLASS_HOST_DEVICE
461
+ void clear_mask(bool enable = true) {
462
+ CUTLASS_PRAGMA_UNROLL
463
+ for (int i = 0; i < kPredicateWordCount; ++i) {
464
+ predicates_[i] = enable ? 0u : predicates_[i];
465
+ }
466
+
467
+ }
468
+
469
+ /// Clears the predicate set efficiently
470
+ CUTLASS_HOST_DEVICE
471
+ void enable_mask() {
472
+ CUTLASS_PRAGMA_UNROLL
473
+ for (int i = 0; i < kPredicateWordCount; ++i) {
474
+ predicates_[i] = 0xffffffff;
475
+ }
476
+
477
+ }
478
+
479
+ /// Sets the predicate mask, overriding value stored in predicate iterator
480
+ CUTLASS_HOST_DEVICE
481
+ void set_mask(Mask const &mask) {
482
+ CUTLASS_PRAGMA_UNROLL
483
+ for (int i = 0; i < kPredicateWordCount; ++i) {
484
+ predicates_[i] = mask[i];
485
+ }
486
+
487
+ }
488
+
489
+ /// Gets the mask
490
+ CUTLASS_HOST_DEVICE
491
+ void get_mask(Mask &mask) {
492
+ CUTLASS_PRAGMA_UNROLL
493
+ for (int i = 0; i < kPredicateWordCount; ++i) {
494
+ mask[i] = predicates_[i];
495
+ }
496
+ }
497
+
498
+ /// add mask for small tiles in ELL
499
+ CUTLASS_DEVICE
500
+ void ell_add_mask(int blocksize) {
501
+
502
+ Mask mask;
503
+
504
+ CUTLASS_PRAGMA_UNROLL
505
+ for (int i = 0; i < kPredicateWordCount; ++i) {
506
+ mask[i] = 0u;
507
+ }
508
+
509
+ CUTLASS_PRAGMA_UNROLL
510
+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
511
+
512
+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
513
+
514
+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
515
+
516
+ int c = access_residual / kAccessesPerVector;
517
+ int v = access_residual % kAccessesPerVector;
518
+
519
+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
520
+ s * ThreadMap::Delta::kStrided);
521
+
522
+ TensorCoord coord = ell_offset_ + iteration_coord;
523
+
524
+ bool guard;
525
+
526
+ if (kAdvanceRank == 0) {
527
+ guard = (coord.strided() < blocksize);
528
+ } else {
529
+ guard = (coord.contiguous() < blocksize);
530
+ }
531
+
532
+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
533
+
534
+ int word_idx = pred_idx / kPredicatesPerWord;
535
+ int residual = pred_idx % kPredicatesPerWord;
536
+ int byte_idx = residual / kPredicatesPerByte;
537
+ int bit_idx = residual % kPredicatesPerByte;
538
+
539
+ mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
540
+
541
+ }
542
+
543
+ CUTLASS_PRAGMA_UNROLL
544
+ for (int i = 0; i < kPredicateWordCount; ++i) {
545
+ mask[i] &= predicates_[i];
546
+ }
547
+ set_mask(mask);
548
+ }
549
+
550
+ /// Returns whether access is valid or not
551
+ CUTLASS_HOST_DEVICE
552
+ bool valid() {
553
+
554
+ int pred_idx =
555
+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
556
+
557
+ int word_idx = pred_idx / kPredicatesPerWord;
558
+ int residual = pred_idx % kPredicatesPerWord;
559
+ int byte_idx = residual / kPredicatesPerByte;
560
+ int bit_idx = residual % kPredicatesPerByte;
561
+
562
+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
563
+ return pred;
564
+
565
+ }
566
+ };
567
+
568
+ ////////////////////////////////////////////////////////////////////////////////
569
+
570
+ /// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
571
+ ///
572
+ /// Satisfies: ForwardTileIteratorConcept |
573
+ /// ReadableContiguousTileIteratorConcept |
574
+ /// WriteableContiguousTileIteratorConcept |
575
+ /// MaskedTileIteratorConcept
576
+ ///
577
+ template <typename Shape_, typename Element_, int AdvanceRank,
578
+ typename ThreadMap_, typename AccessType_>
579
+ class EllPredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
580
+ AdvanceRank, ThreadMap_, AccessType_> {
581
+ public:
582
+ static_assert(
583
+ AdvanceRank == 0 || AdvanceRank == 1,
584
+ "Specialization for pitch-linear iterator may along advance along the "
585
+ "contiguous(rank=0) or strided(rank=1) dimension.");
586
+
587
+ using Shape = Shape_;
588
+ using Element = Element_;
589
+ using Layout = layout::ColumnMajor;
590
+ static int const kAdvanceRank = AdvanceRank;
591
+ using ThreadMap = ThreadMap_;
592
+ using AccessType = AccessType_;
593
+
594
+ using Index = typename Layout::Index;
595
+ using LongIndex = typename Layout::LongIndex;
596
+
597
+ using TensorRef = TensorRef<Element, Layout>;
598
+ using TensorView = TensorView<Element, Layout>;
599
+ using TensorCoord = typename Layout::TensorCoord;
600
+
601
+ using Pointer = Element *;
602
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
603
+
604
+ using UnderlyingIterator = EllPredicatedTileAccessIterator<
605
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
606
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
607
+
608
+ /// Predicate vector stores mask to guard accesses
609
+ using Mask = typename UnderlyingIterator::Mask;
610
+
611
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
612
+
613
+ /// Parameters object is precomputed state and is host-constructible
614
+ class Params {
615
+ private:
616
+ friend EllPredicatedTileAccessIterator;
617
+
618
+ /// Parameters object
619
+ typename UnderlyingIterator::Params params_;
620
+
621
+ public:
622
+
623
+ /// Default ctor
624
+ CUTLASS_HOST_DEVICE
625
+ Params() { }
626
+
627
+ /// Construct the Params object given a pitch-linear tensor's layout
628
+ CUTLASS_HOST_DEVICE
629
+ Params(Layout const &layout)
630
+ : params_(layout::PitchLinear(layout.stride(0))){};
631
+ };
632
+
633
+ private:
634
+ //
635
+ // Data members
636
+ //
637
+
638
+ /// Underlying pitch-linear tile iterator
639
+ UnderlyingIterator iterator_;
640
+
641
+ public:
642
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
643
+ /// and thread ID
644
+ CUTLASS_HOST_DEVICE
645
+ EllPredicatedTileAccessIterator(
646
+ ///< Precomputed parameters object
647
+ Params const &params,
648
+ ///< Pointer to start of tensor
649
+ Pointer pointer,
650
+ ///< Extent of tensor
651
+ TensorCoord extent,
652
+ ///< ID of each participating thread
653
+ int thread_id,
654
+ ///< Initial offset of threadblock
655
+ TensorCoord const &threadblock_offset)
656
+ : iterator_(params.params_, pointer,
657
+ layout::PitchLinearCoord(extent.row(), extent.column()),
658
+ thread_id,
659
+ layout::PitchLinearCoord(threadblock_offset.row(),
660
+ threadblock_offset.column())) {}
661
+
662
+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
663
+ CUTLASS_HOST_DEVICE
664
+ EllPredicatedTileAccessIterator(
665
+ Params const &params, ///< Precomputed parameters object
666
+ Pointer pointer, ///< Pointer to start of tensor
667
+ TensorCoord extent, ///< Extent of tensor
668
+ int thread_id ///< ID of each participating thread
669
+ )
670
+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
671
+ make_Coord(0, 0)) {}
672
+
673
+ /// Overrides the internal iteration index
674
+ CUTLASS_HOST_DEVICE
675
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
676
+
677
+ /// Adds a pointer offset in units of Element
678
+ CUTLASS_HOST_DEVICE
679
+ void add_pointer_offset(LongIndex pointer_offset) {
680
+ iterator_.add_pointer_offset(pointer_offset);
681
+ }
682
+
683
+ /// Advances an iterator along logical dimensions of matrix in units of whole
684
+ /// tiles
685
+ CUTLASS_HOST_DEVICE
686
+ void add_tile_offset(TensorCoord const &tile_offset) {
687
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
688
+ }
689
+
690
+ /// Returns a pointer
691
+ CUTLASS_HOST_DEVICE
692
+ AccessType *get() const {
693
+ return reinterpret_cast<AccessType *>(iterator_.get());
694
+ }
695
+
696
+ CUTLASS_HOST_DEVICE
697
+ int get_k() const {
698
+ return iterator_.get_k();
699
+ }
700
+
701
+ CUTLASS_HOST_DEVICE
702
+ int get_stride() const {
703
+ return iterator_.get_stride();
704
+ }
705
+
706
+ /// Advances to the next tile in memory.
707
+ ///
708
+ /// The first time this method is called, predicates are updated, and the
709
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
710
+ /// Subsequent calls are lightweight and must only update the internal
711
+ /// pointer.
712
+ CUTLASS_HOST_DEVICE
713
+ EllPredicatedTileAccessIterator &operator++() {
714
+ ++iterator_;
715
+ return *this;
716
+ }
717
+
718
+ /// Advances to the next tile in memory.
719
+ ///
720
+ /// The first time this method is called, predicates are updated, and the
721
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
722
+ /// Subsequent calls are lightweight and must only update the internal
723
+ /// pointer.
724
+ CUTLASS_HOST_DEVICE
725
+ EllPredicatedTileAccessIterator operator++(int) {
726
+ EllPredicatedTileAccessIterator self(*this);
727
+ operator++();
728
+ return self;
729
+ }
730
+
731
+ /// Clears the predicate set efficiently
732
+ CUTLASS_HOST_DEVICE
733
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
734
+
735
+ /// Clears the predicate set efficiently
736
+ CUTLASS_HOST_DEVICE
737
+ void enable_mask() { iterator_.enable_mask(); }
738
+
739
+ /// Sets the predicate mask, overriding value stored in predicate iterator
740
+ CUTLASS_HOST_DEVICE
741
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
742
+
743
+ /// Gets the mask
744
+ CUTLASS_HOST_DEVICE
745
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
746
+
747
+ /// add mask for small tiles in ELL
748
+ CUTLASS_DEVICE
749
+ void ell_add_mask(int blocksize) {
750
+ iterator_.ell_add_mask(blocksize);
751
+ }
752
+
753
+ /// Returns whether access is valid or not
754
+ CUTLASS_HOST_DEVICE
755
+ bool valid() {
756
+ return iterator_.valid();
757
+ }
758
+ };
759
+
760
+ ////////////////////////////////////////////////////////////////////////////////
761
+
762
+ /// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
763
+ ///
764
+ /// Satisfies: ForwardTileIteratorConcept |
765
+ /// ReadableContiguousTileIteratorConcept |
766
+ /// WriteableContiguousTileIteratorConcept |
767
+ /// MaskedTileIteratorConcept
768
+ ///
769
+ template <typename Shape_, typename Element_, int AdvanceRank,
770
+ typename ThreadMap_, typename AccessType_>
771
+ class EllPredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
772
+ AdvanceRank, ThreadMap_, AccessType_> {
773
+ public:
774
+ static_assert(
775
+ AdvanceRank == 0 || AdvanceRank == 1,
776
+ "Specialization for pitch-linear iterator may along advance along the "
777
+ "contiguous(rank=0) or strided(rank=1) dimension.");
778
+
779
+ using Shape = Shape_;
780
+ using Element = Element_;
781
+ using Layout = layout::RowMajor;
782
+ static int const kAdvanceRank = AdvanceRank;
783
+ using ThreadMap = ThreadMap_;
784
+ using AccessType = AccessType_;
785
+
786
+ using Index = typename Layout::Index;
787
+ using LongIndex = typename Layout::LongIndex;
788
+
789
+ using TensorRef = TensorRef<Element, Layout>;
790
+ using TensorView = TensorView<Element, Layout>;
791
+ using TensorCoord = typename Layout::TensorCoord;
792
+
793
+ using Pointer = Element *;
794
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
795
+
796
+ using UnderlyingIterator = EllPredicatedTileAccessIterator<
797
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
798
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
799
+
800
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
801
+
802
+ /// Predicate vector stores mask to guard accesses
803
+ using Mask = typename UnderlyingIterator::Mask;
804
+
805
+ /// Parameters object is precomputed state and is host-constructible
806
+ class Params {
807
+ private:
808
+ friend EllPredicatedTileAccessIterator;
809
+
810
+ /// Parameters object
811
+ typename UnderlyingIterator::Params params_;
812
+
813
+ public:
814
+
815
+ /// Default ctor
816
+ CUTLASS_HOST_DEVICE
817
+ Params() { }
818
+
819
+ /// Construct the Params object given a pitch-linear tensor's layout
820
+ CUTLASS_HOST_DEVICE
821
+ Params(Layout const &layout)
822
+ : params_(layout::PitchLinear(layout.stride(0))){};
823
+ };
824
+
825
+ private:
826
+ //
827
+ // Data members
828
+ //
829
+
830
+ /// Underlying pitch-linear tile iterator
831
+ UnderlyingIterator iterator_;
832
+
833
+ public:
834
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
835
+ /// and thread ID
836
+ CUTLASS_HOST_DEVICE
837
+ EllPredicatedTileAccessIterator(
838
+ ///< Precomputed parameters object
839
+ Params const &params,
840
+ ///< Pointer to start of tensor
841
+ Pointer pointer,
842
+ ///< Extent of tensor
843
+ TensorCoord extent,
844
+ ///< ID of each participating thread
845
+ int thread_id,
846
+ ///< Initial offset of threadblock
847
+ TensorCoord const &threadblock_offset)
848
+ : iterator_(params.params_, pointer,
849
+ layout::PitchLinearCoord(extent.column(), extent.row()),
850
+ thread_id,
851
+ layout::PitchLinearCoord(threadblock_offset.column(),
852
+ threadblock_offset.row())) {}
853
+
854
+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
855
+ CUTLASS_HOST_DEVICE
856
+ EllPredicatedTileAccessIterator(
857
+ Params const &params, ///< Precomputed parameters object
858
+ Pointer pointer, ///< Pointer to start of tensor
859
+ TensorCoord extent, ///< Extent of tensor
860
+ int thread_id ///< ID of each participating thread
861
+ )
862
+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
863
+ make_Coord(0, 0)) {}
864
+
865
+ /// Overrides the internal iteration index
866
+ CUTLASS_HOST_DEVICE
867
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
868
+
869
+ /// Adds a pointer offset in units of Element
870
+ CUTLASS_HOST_DEVICE
871
+ void add_pointer_offset(LongIndex pointer_offset) {
872
+ iterator_.add_pointer_offset(pointer_offset);
873
+ }
874
+
875
+ /// Advances an iterator along logical dimensions of matrix in units of whole
876
+ /// tiles
877
+ CUTLASS_HOST_DEVICE
878
+ void add_tile_offset(TensorCoord const &tile_offset) {
879
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
880
+ }
881
+
882
+ /// Returns a pointer
883
+ CUTLASS_HOST_DEVICE
884
+ AccessType *get() const {
885
+ return reinterpret_cast<AccessType *>(iterator_.get());
886
+ }
887
+
888
+ CUTLASS_HOST_DEVICE
889
+ int get_k() const {
890
+ return iterator_.get_k();
891
+ }
892
+
893
+ CUTLASS_HOST_DEVICE
894
+ int get_stride() const {
895
+ return iterator_.get_stride();
896
+ }
897
+
898
+ /// Advances to the next tile in memory.
899
+ ///
900
+ /// The first time this method is called, predicates are updated, and the
901
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
902
+ /// Subsequent calls are lightweight and must only update the internal
903
+ /// pointer.
904
+ CUTLASS_HOST_DEVICE
905
+ EllPredicatedTileAccessIterator &operator++() {
906
+ ++iterator_;
907
+ return *this;
908
+ }
909
+
910
+ /// Advances to the next tile in memory.
911
+ ///
912
+ /// The first time this method is called, predicates are updated, and the
913
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
914
+ /// Subsequent calls are lightweight and must only update the internal
915
+ /// pointer.
916
+ CUTLASS_HOST_DEVICE
917
+ EllPredicatedTileAccessIterator operator++(int) {
918
+ EllPredicatedTileAccessIterator self(*this);
919
+ operator++();
920
+ return self;
921
+ }
922
+
923
+ /// Clears the predicate set efficiently
924
+ CUTLASS_HOST_DEVICE
925
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
926
+
927
+ /// Clears the predicate set efficiently
928
+ CUTLASS_HOST_DEVICE
929
+ void enable_mask() { iterator_.enable_mask(); }
930
+
931
+ /// Sets the predicate mask, overriding value stored in predicate iterator
932
+ CUTLASS_HOST_DEVICE
933
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
934
+
935
+ /// Gets the mask
936
+ CUTLASS_HOST_DEVICE
937
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
938
+
939
+ /// add mask for small tiles in ELL
940
+ CUTLASS_DEVICE
941
+ void ell_add_mask(int blocksize) {
942
+ iterator_.ell_add_mask(blocksize);
943
+ }
944
+
945
+ /// Returns whether access is valid or not
946
+ CUTLASS_HOST_DEVICE
947
+ bool valid() {
948
+ return iterator_.valid();
949
+ }
950
+ };
951
+
952
+ ////////////////////////////////////////////////////////////////////////////////
953
+
954
+ /// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data.
955
+ /// It is mapped to the congruous layout.
956
+ ///
957
+ /// Satisfies: ForwardTileIteratorConcept |
958
+ /// ReadableContiguousTileIteratorConcept |
959
+ /// WriteableContiguousTileIteratorConcept |
960
+ /// MaskedTileIteratorConcept
961
+ ///
962
+
963
+ template <typename Shape_, typename Element_, int AdvanceRank,
964
+ typename ThreadMap_, typename AccessType_, int InterleavedK>
965
+ class EllPredicatedTileAccessIterator<Shape_, Element_,
966
+ layout::ColumnMajorInterleaved<InterleavedK>,
967
+ AdvanceRank, ThreadMap_, AccessType_> {
968
+ public:
969
+ static_assert(
970
+ AdvanceRank == 0 || AdvanceRank == 1,
971
+ "Specialization for pitch-linear iterator may along advance along the "
972
+ "contiguous(rank=0) or strided(rank=1) dimension.");
973
+
974
+ using Shape = Shape_;
975
+ using Element = Element_;
976
+ static int const kInterleavedK = InterleavedK;
977
+ using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
978
+ static int const kAdvanceRank = AdvanceRank;
979
+ using ThreadMap = ThreadMap_;
980
+ using AccessType = AccessType_;
981
+
982
+ using Index = typename Layout::Index;
983
+ using LongIndex = typename Layout::LongIndex;
984
+
985
+ using TensorRef = TensorRef<Element, Layout>;
986
+ using TensorView = TensorView<Element, Layout>;
987
+ using TensorCoord = typename Layout::TensorCoord;
988
+
989
+ using Pointer = Element *;
990
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
991
+
992
+ using UnderlyingIterator = EllPredicatedTileAccessIterator<
993
+ layout::PitchLinearShape<Shape::kRow * kInterleavedK,
994
+ Shape::kColumn / kInterleavedK>,
995
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
996
+ AccessType>;
997
+
998
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
999
+
1000
+ /// Predicate vector stores mask to guard accesses
1001
+ using Mask = typename UnderlyingIterator::Mask;
1002
+
1003
+ /// Parameters object is precomputed state and is host-constructible
1004
+ class Params {
1005
+ private:
1006
+ friend EllPredicatedTileAccessIterator;
1007
+
1008
+ /// Parameters object
1009
+ typename UnderlyingIterator::Params params_;
1010
+
1011
+ public:
1012
+ CUTLASS_HOST_DEVICE
1013
+ Params() {}
1014
+
1015
+ /// Construct the Params object given a pitch-linear tensor's layout
1016
+ CUTLASS_HOST_DEVICE
1017
+ Params(Layout const &layout)
1018
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1019
+ };
1020
+
1021
+ private:
1022
+ //
1023
+ // Data members
1024
+ //
1025
+
1026
+ /// Underlying pitch-linear tile iterator
1027
+ UnderlyingIterator iterator_;
1028
+
1029
+ public:
1030
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1031
+ /// and thread ID
1032
+ CUTLASS_HOST_DEVICE
1033
+ EllPredicatedTileAccessIterator(
1034
+ /// Precomputed parameters object
1035
+ Params const &params,
1036
+ /// Pointer to start of tensor
1037
+ Pointer pointer,
1038
+ /// Extent of tensor
1039
+ TensorCoord extent,
1040
+ /// ID of each participating thread
1041
+ int thread_id,
1042
+ /// Initial offset of threadblock
1043
+ TensorCoord const &threadblock_offset)
1044
+ : iterator_(params.params_, pointer,
1045
+ layout::PitchLinearCoord(extent.row() * kInterleavedK,
1046
+ extent.column() / kInterleavedK),
1047
+ thread_id,
1048
+ layout::PitchLinearCoord(
1049
+ threadblock_offset.row() * kInterleavedK,
1050
+ threadblock_offset.column() / kInterleavedK)) {}
1051
+
1052
+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
1053
+ CUTLASS_HOST_DEVICE
1054
+ EllPredicatedTileAccessIterator(
1055
+ Params const &params, ///< Precomputed parameters object
1056
+ Pointer pointer, ///< Pointer to start of tensor
1057
+ TensorCoord extent, ///< Extent of tensor
1058
+ int thread_id ///< ID of each participating thread
1059
+ )
1060
+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
1061
+ make_Coord(0, 0)) {}
1062
+
1063
+ /// Overrides the internal iteration index
1064
+ CUTLASS_HOST_DEVICE
1065
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1066
+
1067
+ /// Adds a pointer offset in units of Element
1068
+ CUTLASS_HOST_DEVICE
1069
+ void add_pointer_offset(LongIndex pointer_offset) {
1070
+ iterator_.add_pointer_offset(pointer_offset);
1071
+ }
1072
+
1073
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1074
+ /// tiles
1075
+ CUTLASS_HOST_DEVICE
1076
+ void add_tile_offset(TensorCoord const &tile_offset) {
1077
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
1078
+ }
1079
+
1080
+ /// Returns a pointer
1081
+ CUTLASS_HOST_DEVICE
1082
+ AccessType *get() const {
1083
+ return reinterpret_cast<AccessType *>(iterator_.get());
1084
+ }
1085
+
1086
+ CUTLASS_HOST_DEVICE
1087
+ int get_k() const {
1088
+ return iterator_.get_k();
1089
+ }
1090
+
1091
+ CUTLASS_HOST_DEVICE
1092
+ int get_stride() const {
1093
+ return iterator_.get_stride();
1094
+ }
1095
+
1096
+ /// Advances to the next tile in memory.
1097
+ ///
1098
+ /// The first time this method is called, predicates are updated, and the
1099
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1100
+ /// Subsequent calls are lightweight and must only update the internal
1101
+ /// pointer.
1102
+ CUTLASS_HOST_DEVICE
1103
+ EllPredicatedTileAccessIterator &operator++() {
1104
+ ++iterator_;
1105
+ return *this;
1106
+ }
1107
+
1108
+ /// Advances to the next tile in memory.
1109
+ ///
1110
+ /// The first time this method is called, predicates are updated, and the
1111
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1112
+ /// Subsequent calls are lightweight and must only update the internal
1113
+ /// pointer.
1114
+ CUTLASS_HOST_DEVICE
1115
+ EllPredicatedTileAccessIterator operator++(int) {
1116
+ EllPredicatedTileAccessIterator self(*this);
1117
+ operator++();
1118
+ return self;
1119
+ }
1120
+
1121
+ /// Clears the predicate set efficiently
1122
+ CUTLASS_HOST_DEVICE
1123
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1124
+
1125
+ /// Clears the predicate set efficiently
1126
+ CUTLASS_HOST_DEVICE
1127
+ void enable_mask() { iterator_.enable_mask(); }
1128
+
1129
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1130
+ CUTLASS_HOST_DEVICE
1131
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1132
+
1133
+ /// Gets the mask
1134
+ CUTLASS_HOST_DEVICE
1135
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1136
+
1137
+ /// add mask for small tiles in ELL
1138
+ CUTLASS_DEVICE
1139
+ void ell_add_mask(int blocksize) {
1140
+ iterator_.ell_add_mask(blocksize);
1141
+ }
1142
+
1143
+ /// Returns whether access is valid or not
1144
+ CUTLASS_HOST_DEVICE
1145
+ bool valid() { return iterator_.valid(); }
1146
+ };
1147
+
1148
+ ////////////////////////////////////////////////////////////////////////////////
1149
+
1150
+ /// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data.
1151
+ /// It is mapped to the congruous layout.
1152
+ ///
1153
+ /// Satisfies: ForwardTileIteratorConcept |
1154
+ /// ReadableContiguousTileIteratorConcept |
1155
+ /// WriteableContiguousTileIteratorConcept |
1156
+ /// MaskedTileIteratorConcept
1157
+ ///
1158
+ template <typename Shape_, typename Element_, int AdvanceRank,
1159
+ typename ThreadMap_, typename AccessType_, int InterleavedK>
1160
+ class EllPredicatedTileAccessIterator<Shape_, Element_,
1161
+ layout::RowMajorInterleaved<InterleavedK>,
1162
+ AdvanceRank, ThreadMap_, AccessType_> {
1163
+ public:
1164
+ static_assert(
1165
+ AdvanceRank == 0 || AdvanceRank == 1,
1166
+ "Specialization for pitch-linear iterator may along advance along the "
1167
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1168
+
1169
+ using Shape = Shape_;
1170
+ using Element = Element_;
1171
+ static int const kInterleavedK = InterleavedK;
1172
+ using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1173
+ static int const kAdvanceRank = AdvanceRank;
1174
+ using ThreadMap = ThreadMap_;
1175
+ using AccessType = AccessType_;
1176
+
1177
+ using Index = typename Layout::Index;
1178
+ using LongIndex = typename Layout::LongIndex;
1179
+
1180
+ using TensorRef = TensorRef<Element, Layout>;
1181
+ using TensorView = TensorView<Element, Layout>;
1182
+ using TensorCoord = typename Layout::TensorCoord;
1183
+
1184
+ using Pointer = Element *;
1185
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1186
+
1187
+ using UnderlyingIterator = EllPredicatedTileAccessIterator<
1188
+ layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1189
+ Shape::kRow / kInterleavedK>,
1190
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
1191
+ AccessType>;
1192
+
1193
+
1194
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1195
+
1196
+ /// Predicate vector stores mask to guard accesses
1197
+ using Mask = typename UnderlyingIterator::Mask;
1198
+
1199
+ /// Parameters object is precomputed state and is host-constructible
1200
+ class Params {
1201
+ private:
1202
+ friend EllPredicatedTileAccessIterator;
1203
+
1204
+ /// Parameters object
1205
+ typename UnderlyingIterator::Params params_;
1206
+
1207
+ public:
1208
+ CUTLASS_HOST_DEVICE
1209
+ Params() {}
1210
+
1211
+ /// Construct the Params object given a pitch-linear tensor's layout
1212
+ CUTLASS_HOST_DEVICE
1213
+ Params(Layout const &layout)
1214
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1215
+ };
1216
+
1217
+ private:
1218
+ //
1219
+ // Data members
1220
+ //
1221
+
1222
+ /// Underlying pitch-linear tile iterator
1223
+ UnderlyingIterator iterator_;
1224
+
1225
+ public:
1226
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1227
+ /// and thread ID
1228
+ CUTLASS_HOST_DEVICE
1229
+ EllPredicatedTileAccessIterator(
1230
+ /// Precomputed parameters object
1231
+ Params const &params,
1232
+ /// Pointer to start of tensor
1233
+ Pointer pointer,
1234
+ /// Extent of tensor
1235
+ TensorCoord extent,
1236
+ /// ID of each participating thread
1237
+ int thread_id,
1238
+ /// Initial offset of threadblock
1239
+ TensorCoord const &threadblock_offset)
1240
+ : iterator_(params.params_, pointer,
1241
+ layout::PitchLinearCoord(extent.column() * kInterleavedK,
1242
+ extent.row() / kInterleavedK),
1243
+ thread_id,
1244
+ layout::PitchLinearCoord(
1245
+ threadblock_offset.column() * kInterleavedK,
1246
+ threadblock_offset.row() / kInterleavedK)) {}
1247
+
1248
+ /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
1249
+ CUTLASS_HOST_DEVICE
1250
+ EllPredicatedTileAccessIterator(
1251
+ Params const &params, ///< Precomputed parameters object
1252
+ Pointer pointer, ///< Pointer to start of tensor
1253
+ TensorCoord extent, ///< Extent of tensor
1254
+ int thread_id ///< ID of each participating thread
1255
+ )
1256
+ : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
1257
+ make_Coord(0, 0)) {}
1258
+
1259
+ /// Overrides the internal iteration index
1260
+ CUTLASS_HOST_DEVICE
1261
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1262
+
1263
+ /// Adds a pointer offset in units of Element
1264
+ CUTLASS_HOST_DEVICE
1265
+ void add_pointer_offset(LongIndex pointer_offset) {
1266
+ iterator_.add_pointer_offset(pointer_offset);
1267
+ }
1268
+
1269
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1270
+ /// tiles
1271
+ CUTLASS_HOST_DEVICE
1272
+ void add_tile_offset(TensorCoord const &tile_offset) {
1273
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
1274
+ }
1275
+
1276
+ /// Returns a pointer
1277
+ CUTLASS_HOST_DEVICE
1278
+ AccessType *get() const {
1279
+ return reinterpret_cast<AccessType *>(iterator_.get());
1280
+ }
1281
+
1282
+ CUTLASS_HOST_DEVICE
1283
+ int get_k() const {
1284
+ return iterator_.get_k();
1285
+ }
1286
+
1287
+ CUTLASS_HOST_DEVICE
1288
+ int get_stride() const {
1289
+ return iterator_.get_stride();
1290
+ }
1291
+
1292
+ /// Advances to the next tile in memory.
1293
+ ///
1294
+ /// The first time this method is called, predicates are updated, and the
1295
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1296
+ /// Subsequent calls are lightweight and must only update the internal
1297
+ /// pointer.
1298
+ CUTLASS_HOST_DEVICE
1299
+ EllPredicatedTileAccessIterator &operator++() {
1300
+ ++iterator_;
1301
+ return *this;
1302
+ }
1303
+
1304
+ /// Advances to the next tile in memory.
1305
+ ///
1306
+ /// The first time this method is called, predicates are updated, and the
1307
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1308
+ /// Subsequent calls are lightweight and must only update the internal
1309
+ /// pointer.
1310
+ CUTLASS_HOST_DEVICE
1311
+ EllPredicatedTileAccessIterator operator++(int) {
1312
+ EllPredicatedTileAccessIterator self(*this);
1313
+ operator++();
1314
+ return self;
1315
+ }
1316
+
1317
+ /// Clears the predicate set efficiently
1318
+ CUTLASS_HOST_DEVICE
1319
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1320
+
1321
+ /// Clears the predicate set efficiently
1322
+ CUTLASS_HOST_DEVICE
1323
+ void enable_mask() { iterator_.enable_mask(); }
1324
+
1325
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1326
+ CUTLASS_HOST_DEVICE
1327
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1328
+
1329
+ /// Gets the mask
1330
+ CUTLASS_HOST_DEVICE
1331
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1332
+
1333
+ /// add mask for small tiles in ELL
1334
+ CUTLASS_DEVICE
1335
+ void ell_add_mask(int blocksize) {
1336
+ iterator_.ell_add_mask(blocksize);
1337
+ }
1338
+
1339
+ /// Returns whether access is valid or not
1340
+ CUTLASS_HOST_DEVICE
1341
+ bool valid() { return iterator_.valid(); }
1342
+ };
1343
+
1344
+ ////////////////////////////////////////////////////////////////////////////////
1345
+
1346
+ } // namespace threadblock
1347
+ } // namespace transform
1348
+ } // namespace cutlass
1349
+
1350
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h ADDED
@@ -0,0 +1,1315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/arch/memory.h"
38
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
39
+
40
+ #include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h"
41
+ #include "cutlass/transform/threadblock/ell_iterator.h"
42
+
43
+ ////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace transform {
47
+ namespace threadblock {
48
+
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+
51
+ /// EllPredicatedTileIterator
52
+ ///
53
+ /// Satisfies: ForwardTileIteratorConcept |
54
+ /// ReadableContiguousTileIteratorConcept |
55
+ /// WriteableContiguousTileIteratorConcept |
56
+ /// MaskedTileIteratorConcept
57
+ ///
58
+ /// Regular tile iterator using a precomputed control structure to minimize register liveness
59
+ /// and integer arithmetic.
60
+ ///
61
+ /// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
62
+ ///
63
+ /// Base pointer and tensor extents may be specified at the time the iterator is constructed.
64
+ /// Subsequently, they are assumed to be immutable.
65
+ ///
66
+ /// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
67
+ /// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
68
+ ///
69
+ /// Visitation order is intended to first visit a "residual" tile that may be partially full in
70
+ /// both the advance dimension and the steady-state dimension. This is assumed to be the last
71
+ /// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
72
+ /// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
73
+ /// accesses may be performed without updating internal predicates and are efficient in terms of
74
+ /// live register state and pointer arithmetic instructions.
75
+ ///
76
+ /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
77
+ /// outside any looping structure to minimize integer arithmetic.
78
+ ///
79
+ /// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
80
+ /// the iterator.
81
+ ///
82
+ ///
83
+ /// Example:
84
+ ///
85
+ /// An efficient pipeline structure may be constructed as follows:
86
+ ///
87
+ // template <typename Iterator>
88
+ // __global__ void kernel(
89
+ // typename Iterator::Params params,
90
+ // typename Iterator::Element *ptr,
91
+ // TensorCoord extent) {
92
+ //
93
+ // typename Iterator::Fragment fragment;
94
+ //
95
+ // TensorCoord threadblock_offset(0, 0);
96
+ //
97
+ // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
98
+ //
99
+ //
100
+ // fragment = *iter; // load "residue" tile first
101
+ // ++iter; // advance to first "steady state" tile and update internal masks
102
+ //
103
+ //
104
+ // #pragma unroll
105
+ // for (int i = Remaining - 1; i >= 0; --i) {
106
+ //
107
+ // f(fragment);
108
+ //
109
+ // if (!i) {
110
+ // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
111
+ // }
112
+ //
113
+ // fragment = *iter; // load tile during "steady state" phase
114
+ // ++iter; // advance to next tile - lightweight due to steady-state masks
115
+ // }
116
+ // }
117
+ //
118
+ // void host(TensorView<Element, 2, layout::PitchLinear> view) {
119
+ //
120
+ // using Iterator = transform::threadblock::EllPredicatedTileIterator;
121
+ //
122
+ // typename Iterator::Params params(view.layout());
123
+ //
124
+ // kernel<Iterator>(params, view.data());
125
+ // }
126
+ ///
127
+ ///
128
+ template <
129
+ typename Shape,
130
+ typename Element,
131
+ typename Layout,
132
+ int AdvanceRank,
133
+ typename ThreadMap,
134
+ int AccessSize = ThreadMap::kElementsPerAccess
135
+ >
136
+ class EllPredicatedTileIterator;
137
+
138
+ ////////////////////////////////////////////////////////////////////////////////
139
+
140
+ /// Specialization of EllPredicatedTileIterator for pitch-linear data.
141
+ ///
142
+ /// Satisfies: ForwardTileIteratorConcept |
143
+ /// ReadableContiguousTileIteratorConcept |
144
+ /// WriteableContiguousTileIteratorConcept |
145
+ /// MaskedTileIteratorConcept
146
+ ///
147
+ template <typename Shape_, typename Element_, int AdvanceRank,
148
+ typename ThreadMap_, int AccessSize>
149
+ class EllPredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
150
+ ThreadMap_, AccessSize> {
151
+ public:
152
+ static_assert(
153
+ AdvanceRank == 0 || AdvanceRank == 1,
154
+ "Specialization for pitch-linear iterator may along advance along the "
155
+ "contiguous(rank=0) or strided(rank=1) dimension.");
156
+
157
+ using Shape = Shape_;
158
+ using Element = Element_;
159
+ using Layout = layout::PitchLinear;
160
+ static int const kAdvanceRank = AdvanceRank;
161
+ using ThreadMap = ThreadMap_;
162
+
163
+ using Index = typename Layout::Index;
164
+ using LongIndex = typename Layout::LongIndex;
165
+
166
+ using TensorRef = TensorRef<Element, Layout>;
167
+ using TensorView = TensorView<Element, Layout>;
168
+ using TensorCoord = typename Layout::TensorCoord;
169
+
170
+ using Pointer = Element *;
171
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
172
+
173
+ /// Type used for internal memory accesses
174
+ using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
175
+
176
+ /// Underlying iterator to compute the addresses
177
+ using TileAccessIterator =
178
+ EllPredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
179
+ ThreadMap, AccessType>;
180
+
181
+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
182
+
183
+ /// Fragment object to be loaded or stored
184
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
185
+ ThreadMap::kElementsPerAccess>;
186
+
187
+ /// Predicate vector stores mask to guard accesses
188
+ using Mask = typename TileAccessIterator::Mask;
189
+
190
+ /// Iterator for ELL storage
191
+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
192
+
193
+ /// Parameters object is precomputed state and is host-constructible
194
+ class Params {
195
+ public:
196
+ friend EllPredicatedTileIterator;
197
+
198
+ private:
199
+ /// Parameters object
200
+ typename TileAccessIterator::Params params_;
201
+
202
+ public:
203
+ /// Construct the Params object given a pitch-linear tensor's layout
204
+ CUTLASS_HOST_DEVICE
205
+ Params(Layout const &layout) : params_(layout) { }
206
+
207
+ CUTLASS_HOST_DEVICE
208
+ Params() { }
209
+ };
210
+
211
+ private:
212
+ /// Internal pointer type permits fast address arithmetic
213
+ using BytePointer = char *;
214
+
215
+ private:
216
+ //
217
+ // Data members
218
+ //
219
+
220
+ /// Data member to the tile access iterator
221
+ TileAccessIterator address_iterator_;
222
+
223
+ public:
224
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
225
+ /// and thread ID
226
+ CUTLASS_HOST_DEVICE
227
+ EllPredicatedTileIterator(
228
+ /// Precomputed parameters object
229
+ Params const &params,
230
+ /// Pointer to start of tensor
231
+ Pointer pointer,
232
+ /// Extent of tensor
233
+ TensorCoord extent,
234
+ /// ID of each participating thread
235
+ int thread_id,
236
+ /// Initial offset of threadblock
237
+ TensorCoord const &threadblock_offset)
238
+ : address_iterator_(params.params_, pointer, extent, thread_id,
239
+ threadblock_offset) {}
240
+
241
+ /// Construct a EllPredicatedTileIterator with zero threadblock offset
242
+ CUTLASS_HOST_DEVICE
243
+ EllPredicatedTileIterator(
244
+ Params const &params, ///< Precomputed parameters object
245
+ Pointer pointer, ///< Pointer to start of tensor
246
+ TensorCoord extent, ///< Extent of tensor
247
+ int thread_id ///< ID of each participating thread
248
+ )
249
+ : EllPredicatedTileIterator(params, pointer, extent, thread_id,
250
+ make_Coord(0, 0)) {}
251
+
252
+ /// Adds a pointer offset in units of Element
253
+ CUTLASS_HOST_DEVICE
254
+ void add_pointer_offset(LongIndex pointer_offset) {
255
+ address_iterator_.add_pointer_offset(pointer_offset);
256
+ }
257
+
258
+ /// Advances to the next tile in memory.
259
+ ///
260
+ /// The first time this method is called, predicates are updated, and the
261
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
262
+ /// Subsequent calls are lightweight and must only update the internal
263
+ /// pointer.
264
+ CUTLASS_HOST_DEVICE
265
+ EllPredicatedTileIterator &operator++() {
266
+ if (kAdvanceRank)
267
+ address_iterator_.add_tile_offset({0, 1});
268
+ else
269
+ address_iterator_.add_tile_offset({1, 0});
270
+
271
+ return *this;
272
+ }
273
+
274
+ /// Advances to the next tile in memory.
275
+ ///
276
+ /// The first time this method is called, predicates are updated, and the
277
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
278
+ /// Subsequent calls are lightweight and must only update the internal
279
+ /// pointer.
280
+ CUTLASS_HOST_DEVICE
281
+ EllPredicatedTileIterator operator++(int) {
282
+ EllPredicatedTileIterator self(*this);
283
+ operator++();
284
+ return self;
285
+ }
286
+
287
+ /// Returns a stride
288
+ CUTLASS_HOST_DEVICE
289
+ int get_stride() const { return address_iterator_.get_stride(); }
290
+
291
+ /// Clears the predicate set efficiently
292
+ CUTLASS_HOST_DEVICE
293
+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
294
+
295
+ /// Clears the predicate set efficiently
296
+ CUTLASS_HOST_DEVICE
297
+ void enable_mask() { address_iterator_.enable_mask(); }
298
+
299
+ /// Sets the predicate mask, overriding value stored in predicate iterator
300
+ CUTLASS_HOST_DEVICE
301
+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
302
+
303
+ /// Gets the mask
304
+ CUTLASS_HOST_DEVICE
305
+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
306
+
307
+ /// add mask for small tiles in ELL
308
+ CUTLASS_HOST_DEVICE
309
+ void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); }
310
+
311
+ CUTLASS_DEVICE
312
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
313
+ load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
314
+ }
315
+
316
+ CUTLASS_DEVICE
317
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
318
+
319
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
320
+
321
+ CUTLASS_PRAGMA_UNROLL
322
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
323
+ CUTLASS_PRAGMA_UNROLL
324
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
325
+
326
+ CUTLASS_PRAGMA_UNROLL
327
+ for (int v = 0; v < kAccessesPerVector; ++v) {
328
+
329
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
330
+
331
+ address_iterator_.set_iteration_index(idx);
332
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
333
+
334
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
335
+
336
+ cutlass::arch::global_load<AccessType,
337
+ sizeof(AccessType)
338
+ >(
339
+ frag_ptr[idx], access_ptr, address_iterator_.valid());
340
+
341
+ ++address_iterator_;
342
+ }
343
+ }
344
+ }
345
+ }
346
+
347
+ /// Loads a fragment from memory
348
+ CUTLASS_DEVICE
349
+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
350
+
351
+ CUTLASS_DEVICE
352
+ void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) {
353
+
354
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
355
+
356
+ CUTLASS_PRAGMA_UNROLL
357
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
358
+ CUTLASS_PRAGMA_UNROLL
359
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
360
+ CUTLASS_PRAGMA_UNROLL
361
+ for (int v = 0; v < kAccessesPerVector; ++v) {
362
+
363
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
364
+ address_iterator_.set_iteration_index(idx);
365
+ LongIndex ell_offset = 0;
366
+
367
+ int k_offset = address_iterator_.get_k();
368
+ ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element);
369
+
370
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + ell_offset;
371
+
372
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
373
+
374
+ bool is_valid = address_iterator_.valid();
375
+ is_valid = is_valid && (ell_offset >= 0);
376
+
377
+ cutlass::arch::global_load<AccessType,
378
+ sizeof(AccessType)
379
+ >(
380
+ frag_ptr[idx], access_ptr, is_valid);
381
+
382
+ ++address_iterator_;
383
+ }
384
+ }
385
+ }
386
+ }
387
+
388
+ CUTLASS_DEVICE
389
+ void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) {
390
+
391
+ LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element);
392
+
393
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
394
+
395
+ CUTLASS_PRAGMA_UNROLL
396
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
397
+ CUTLASS_PRAGMA_UNROLL
398
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
399
+
400
+ CUTLASS_PRAGMA_UNROLL
401
+ for (int v = 0; v < kAccessesPerVector; ++v) {
402
+
403
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
404
+
405
+ address_iterator_.set_iteration_index(idx);
406
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + ell_offset;
407
+
408
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
409
+
410
+ bool is_valid = address_iterator_.valid();
411
+ is_valid = is_valid && (ell_offset >= 0);
412
+
413
+ cutlass::arch::global_load<AccessType,
414
+ sizeof(AccessType)
415
+ >(
416
+ frag_ptr[idx], access_ptr, is_valid);
417
+
418
+ ++address_iterator_;
419
+ }
420
+ }
421
+ }
422
+ }
423
+ /// Store a fragment to memory
424
+ CUTLASS_DEVICE
425
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
426
+ store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
427
+ }
428
+
429
+ /// Store a fragment to memory
430
+ CUTLASS_DEVICE
431
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
432
+ address_iterator_.set_iteration_index(0);
433
+ AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
434
+
435
+ CUTLASS_PRAGMA_UNROLL
436
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
437
+ CUTLASS_PRAGMA_UNROLL
438
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
439
+ CUTLASS_PRAGMA_UNROLL
440
+ for (int v = 0; v < kAccessesPerVector; ++v) {
441
+
442
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
443
+
444
+ char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
445
+ AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
446
+
447
+ if (address_iterator_.valid()) {
448
+ *access_ptr = frag_ptr[idx];
449
+ }
450
+ ++address_iterator_;
451
+ }
452
+ }
453
+ }
454
+ }
455
+
456
+ /// Store a fragment to memory
457
+ CUTLASS_DEVICE
458
+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
459
+ };
460
+
461
+ ////////////////////////////////////////////////////////////////////////////////
462
+
463
+ /// Specialization of EllPredicatedTileIterator for pitch-linear data.
464
+ ///
465
+ /// Satisfies: ForwardTileIteratorConcept |
466
+ /// ReadableContiguousTileIteratorConcept |
467
+ /// WriteableContiguousTileIteratorConcept |
468
+ /// MaskedTileIteratorConcept
469
+ ///
470
+ template <
471
+ typename Shape_,
472
+ typename Element_,
473
+ int AdvanceRank,
474
+ typename ThreadMap_,
475
+ int AccessSize
476
+ >
477
+ class EllPredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, AccessSize> {
478
+ public:
479
+
480
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
481
+ "Specialization for pitch-linear iterator may along advance along the "
482
+ "contiguous(rank=0) or strided(rank=1) dimension.");
483
+
484
+ using Shape = Shape_;
485
+ using Element = Element_;
486
+ using Layout = layout::ColumnMajor;
487
+ static int const kAdvanceRank = AdvanceRank;
488
+ using ThreadMap = ThreadMap_;
489
+
490
+ using Index = typename Layout::Index;
491
+ using LongIndex = typename Layout::LongIndex;
492
+
493
+ using TensorRef = TensorRef<Element, Layout>;
494
+ using TensorView = TensorView<Element, Layout>;
495
+ using TensorCoord = typename Layout::TensorCoord;
496
+
497
+ using Pointer = Element *;
498
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
499
+
500
+ using UnderlyingIterator = EllPredicatedTileIterator<
501
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
502
+ Element,
503
+ layout::PitchLinear,
504
+ (kAdvanceRank == 0 ? 0 : 1),
505
+ ThreadMap,
506
+ AccessSize
507
+ >;
508
+
509
+ using AccessType = typename UnderlyingIterator::AccessType;
510
+
511
+ /// Fragment object to be loaded or stored
512
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
513
+
514
+ /// Predicate vector stores mask to guard accesses
515
+ using Mask = typename UnderlyingIterator::Mask;
516
+
517
+ /// Iterator for ELL storage
518
+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
519
+
520
+ /// Parameters object is precomputed state and is host-constructible
521
+ class Params {
522
+ private:
523
+
524
+ friend EllPredicatedTileIterator;
525
+
526
+ /// Parameters object
527
+ typename UnderlyingIterator::Params params_;
528
+
529
+ public:
530
+
531
+ CUTLASS_HOST_DEVICE
532
+ Params() { }
533
+
534
+ /// Construct the Params object given a pitch-linear tensor's layout
535
+ CUTLASS_HOST_DEVICE
536
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
537
+
538
+ }
539
+ };
540
+
541
+
542
+ private:
543
+
544
+ //
545
+ // Data members
546
+ //
547
+
548
+ /// Underlying pitch-linear tile iterator
549
+ UnderlyingIterator iterator_;
550
+
551
+ public:
552
+
553
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
554
+ CUTLASS_HOST_DEVICE
555
+ EllPredicatedTileIterator(
556
+ Params const &params, ///< Precomputed parameters object
557
+ Pointer pointer, ///< Pointer to start of tensor
558
+ TensorCoord extent, ///< Extent of tensor
559
+ int thread_id, ///< ID of each participating thread
560
+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock
561
+ ):
562
+ iterator_(
563
+ params.params_,
564
+ pointer,
565
+ layout::PitchLinearCoord(extent.row(), extent.column()),
566
+ thread_id,
567
+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
568
+ ) { }
569
+
570
+ /// Construct a EllPredicatedTileIterator with zero threadblock offset
571
+ CUTLASS_HOST_DEVICE
572
+ EllPredicatedTileIterator(
573
+ Params const &params, ///< Precomputed parameters object
574
+ Pointer pointer, ///< Pointer to start of tensor
575
+ TensorCoord extent, ///< Extent of tensor
576
+ int thread_id ///< ID of each participating thread
577
+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
578
+
579
+ /// Adds a pointer offset in units of Element
580
+ CUTLASS_HOST_DEVICE
581
+ void add_pointer_offset(LongIndex pointer_offset) {
582
+ iterator_.add_pointer_offset(pointer_offset);
583
+ }
584
+
585
+ /// Advances to the next tile in memory.
586
+ ///
587
+ /// The first time this method is called, predicates are updated, and the iterator's
588
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
589
+ /// are lightweight and must only update the internal pointer.
590
+ CUTLASS_HOST_DEVICE
591
+ EllPredicatedTileIterator &operator++() {
592
+ ++iterator_;
593
+ return *this;
594
+ }
595
+
596
+ /// Advances to the next tile in memory.
597
+ ///
598
+ /// The first time this method is called, predicates are updated, and the iterator's
599
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
600
+ /// are lightweight and must only update the internal pointer.
601
+ CUTLASS_HOST_DEVICE
602
+ EllPredicatedTileIterator operator++(int) {
603
+ EllPredicatedTileIterator self(*this);
604
+ operator++();
605
+ return self;
606
+ }
607
+
608
+ /// Returns a stride
609
+ CUTLASS_HOST_DEVICE
610
+ int get_stride() const { return iterator_.get_stride(); }
611
+
612
+ /// Clears the predicate set efficiently
613
+ CUTLASS_HOST_DEVICE
614
+ void clear_mask(bool enable = true) {
615
+ iterator_.clear_mask(enable);
616
+ }
617
+
618
+ /// Clears the predicate set efficiently
619
+ CUTLASS_HOST_DEVICE
620
+ void enable_mask() {
621
+ iterator_.enable_mask();
622
+ }
623
+
624
+ /// Sets the predicate mask, overriding value stored in predicate iterator
625
+ CUTLASS_HOST_DEVICE
626
+ void set_mask(Mask const &mask) {
627
+ iterator_.set_mask(mask);
628
+ }
629
+
630
+ /// Gets the mask
631
+ CUTLASS_HOST_DEVICE
632
+ void get_mask(Mask &mask) {
633
+ iterator_.get_mask(mask);
634
+ }
635
+
636
+ /// add mask for small tiles in ELL
637
+ CUTLASS_HOST_DEVICE
638
+ void ell_add_mask(int blocksize) {
639
+ iterator_.ell_add_mask(blocksize);
640
+ }
641
+
642
+ /// Loads a fragment from memory
643
+ CUTLASS_DEVICE
644
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
645
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
646
+ }
647
+
648
+ /// Loads a fragment from memory
649
+ CUTLASS_DEVICE
650
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
651
+ iterator_.load_with_byte_offset(frag, byte_offset);
652
+ }
653
+
654
+ /// Loads a fragment from memory
655
+ CUTLASS_DEVICE
656
+ void load(Fragment &frag) {
657
+ load_with_pointer_offset(frag, 0);
658
+ }
659
+
660
+ CUTLASS_DEVICE
661
+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
662
+ iterator_.load_with_ell_index(frag, ell_iter);
663
+ }
664
+
665
+ CUTLASS_DEVICE
666
+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
667
+ iterator_.load_with_ell_index_fast(frag, ell_iter);
668
+ }
669
+
670
+ /// Store a fragment to memory
671
+ CUTLASS_DEVICE
672
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
673
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
674
+ }
675
+
676
+ /// Store a fragment to memory
677
+ CUTLASS_DEVICE
678
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
679
+ iterator_.store_with_byte_offset(frag, byte_offset);
680
+ }
681
+
682
+ /// Store a fragment to memory
683
+ CUTLASS_DEVICE
684
+ void store(Fragment const &frag) {
685
+ store_with_pointer_offset(frag, 0);
686
+ }
687
+ };
688
+
689
+ ////////////////////////////////////////////////////////////////////////////////
690
+
691
+ /// Specialization of EllPredicatedTileIterator for pitch-linear data.
692
+ ///
693
+ /// Satisfies: ForwardTileIteratorConcept |
694
+ /// ReadableContiguousTileIteratorConcept |
695
+ /// WriteableContiguousTileIteratorConcept |
696
+ /// MaskedTileIteratorConcept
697
+ ///
698
+ template <
699
+ typename Shape_,
700
+ typename Element_,
701
+ int AdvanceRank,
702
+ typename ThreadMap_,
703
+ int AccessSize
704
+ >
705
+ class EllPredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, AccessSize> {
706
+ public:
707
+
708
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
709
+ "Specialization for pitch-linear iterator may along advance along the "
710
+ "contiguous(rank=0) or strided(rank=1) dimension.");
711
+
712
+ using Shape = Shape_;
713
+ using Element = Element_;
714
+ using Layout = layout::RowMajor;
715
+ static int const kAdvanceRank = AdvanceRank;
716
+ using ThreadMap = ThreadMap_;
717
+
718
+ using Index = typename Layout::Index;
719
+ using LongIndex = typename Layout::LongIndex;
720
+
721
+ using TensorRef = TensorRef<Element, Layout>;
722
+ using TensorView = TensorView<Element, Layout>;
723
+ using TensorCoord = typename Layout::TensorCoord;
724
+
725
+ using Pointer = Element *;
726
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
727
+
728
+ using UnderlyingIterator = EllPredicatedTileIterator<
729
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
730
+ Element,
731
+ layout::PitchLinear,
732
+ (kAdvanceRank == 0 ? 1 : 0),
733
+ ThreadMap,
734
+ AccessSize
735
+ >;
736
+
737
+ using AccessType = typename UnderlyingIterator::AccessType;
738
+
739
+ /// Fragment object to be loaded or stored
740
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
741
+
742
+ /// Predicate vector stores mask to guard accesses
743
+ using Mask = typename UnderlyingIterator::Mask;
744
+
745
+ /// Iterator for ELL storage
746
+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
747
+
748
+ /// Parameters object is precomputed state and is host-constructible
749
+ class Params {
750
+ private:
751
+
752
+ friend EllPredicatedTileIterator;
753
+
754
+ /// Parameters object
755
+ typename UnderlyingIterator::Params params_;
756
+
757
+ public:
758
+
759
+ CUTLASS_HOST_DEVICE
760
+ Params() { }
761
+
762
+ /// Construct the Params object given a pitch-linear tensor's layout
763
+ CUTLASS_HOST_DEVICE
764
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
765
+
766
+ };
767
+ };
768
+
769
+
770
+ private:
771
+
772
+ //
773
+ // Data members
774
+ //
775
+
776
+ /// Underlying pitch-linear tile iterator
777
+ UnderlyingIterator iterator_;
778
+
779
+ public:
780
+
781
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
782
+ CUTLASS_HOST_DEVICE
783
+ EllPredicatedTileIterator(
784
+ Params const &params, ///< Precomputed parameters object
785
+ Pointer pointer, ///< Pointer to start of tensor
786
+ TensorCoord extent, ///< Extent of tensor
787
+ int thread_id, ///< ID of each participating thread
788
+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock
789
+ ):
790
+ iterator_(
791
+ params.params_,
792
+ pointer,
793
+ layout::PitchLinearCoord(extent.column(), extent.row()),
794
+ thread_id,
795
+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
796
+ ) { }
797
+
798
+ /// Construct a EllPredicatedTileIterator with zero threadblock offset
799
+ CUTLASS_HOST_DEVICE
800
+ EllPredicatedTileIterator(
801
+ Params const &params, ///< Precomputed parameters object
802
+ Pointer pointer, ///< Pointer to start of tensor
803
+ TensorCoord extent, ///< Extent of tensor
804
+ int thread_id ///< ID of each participating thread
805
+ ): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
806
+
807
+ /// Adds a pointer offset in units of Element
808
+ CUTLASS_HOST_DEVICE
809
+ void add_pointer_offset(LongIndex pointer_offset) {
810
+ iterator_.add_pointer_offset(pointer_offset);
811
+ }
812
+
813
+ /// Advances to the next tile in memory.
814
+ ///
815
+ /// The first time this method is called, predicates are updated, and the iterator's
816
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
817
+ /// are lightweight and must only update the internal pointer.
818
+ CUTLASS_HOST_DEVICE
819
+ EllPredicatedTileIterator &operator++() {
820
+ ++iterator_;
821
+ return *this;
822
+ }
823
+
824
+ /// Advances to the next tile in memory.
825
+ ///
826
+ /// The first time this method is called, predicates are updated, and the iterator's
827
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
828
+ /// are lightweight and must only update the internal pointer.
829
+ CUTLASS_HOST_DEVICE
830
+ EllPredicatedTileIterator operator++(int) {
831
+ EllPredicatedTileIterator self(*this);
832
+ operator++();
833
+ return self;
834
+ }
835
+
836
+ /// Returns a stride
837
+ CUTLASS_HOST_DEVICE
838
+ int get_stride() const { return iterator_.get_stride(); }
839
+
840
+ /// Clears the predicate set efficiently
841
+ CUTLASS_HOST_DEVICE
842
+ void clear_mask(bool enable = true) {
843
+ iterator_.clear_mask(enable);
844
+ }
845
+
846
+ /// Clears the predicate set efficiently
847
+ CUTLASS_HOST_DEVICE
848
+ void enable_mask() {
849
+ iterator_.enable_mask();
850
+ }
851
+
852
+ /// Sets the predicate mask, overriding value stored in predicate iterator
853
+ CUTLASS_HOST_DEVICE
854
+ void set_mask(Mask const &mask) {
855
+ iterator_.set_mask(mask);
856
+ }
857
+
858
+ /// Gets the mask
859
+ CUTLASS_HOST_DEVICE
860
+ void get_mask(Mask &mask) {
861
+ iterator_.get_mask(mask);
862
+ }
863
+
864
+ /// add mask for small tiles in ELL
865
+ CUTLASS_HOST_DEVICE
866
+ void ell_add_mask(int blocksize) {
867
+ iterator_.ell_add_mask(blocksize);
868
+ }
869
+
870
+ /// Loads a fragment from memory
871
+ CUTLASS_DEVICE
872
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
873
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
874
+ }
875
+
876
+ /// Loads a fragment from memory
877
+ CUTLASS_DEVICE
878
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
879
+ iterator_.load_with_byte_offset(frag, byte_offset);
880
+ }
881
+
882
+ /// Loads a fragment from memory
883
+ CUTLASS_DEVICE
884
+ void load(Fragment &frag) {
885
+ load_with_pointer_offset(frag, 0);
886
+ }
887
+
888
+ CUTLASS_DEVICE
889
+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
890
+ iterator_.load_with_ell_index(frag, ell_iter);
891
+ }
892
+
893
+ CUTLASS_DEVICE
894
+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
895
+ iterator_.load_with_ell_index_fast(frag, ell_iter);
896
+ }
897
+
898
+ /// Store a fragment to memory
899
+ CUTLASS_DEVICE
900
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
901
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
902
+ }
903
+
904
+ /// Store a fragment to memory
905
+ CUTLASS_DEVICE
906
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
907
+ iterator_.store_with_byte_offset(frag, byte_offset);
908
+ }
909
+
910
+ /// Store a fragment to memory
911
+ CUTLASS_DEVICE
912
+ void store(Fragment const &frag) {
913
+ store_with_pointer_offset(frag, 0);
914
+ }
915
+ };
916
+
917
+ ////////////////////////////////////////////////////////////////////////////////
918
+
919
+ /// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped
920
+ /// to the congruous layout.
921
+ ///
922
+ /// Satisfies: ForwardTileIteratorConcept |
923
+ /// ReadableContiguousTileIteratorConcept |
924
+ /// WriteableContiguousTileIteratorConcept |
925
+ /// MaskedTileIteratorConcept
926
+ ///
927
+
928
+ template <typename Shape_, typename Element_, int AdvanceRank,
929
+ typename ThreadMap_, int AccessSize, int InterleavedK>
930
+ class EllPredicatedTileIterator<Shape_, Element_,
931
+ layout::ColumnMajorInterleaved<InterleavedK>,
932
+ AdvanceRank, ThreadMap_, AccessSize> {
933
+ public:
934
+ static_assert(
935
+ AdvanceRank == 0 || AdvanceRank == 1,
936
+ "Specialization for pitch-linear iterator may along advance along the "
937
+ "contiguous(rank=0) or strided(rank=1) dimension.");
938
+
939
+ using Shape = Shape_;
940
+ using Element = Element_;
941
+ static int const kInterleavedK = InterleavedK;
942
+ using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
943
+ static int const kAdvanceRank = AdvanceRank;
944
+ using ThreadMap = ThreadMap_;
945
+
946
+ using Index = typename Layout::Index;
947
+ using LongIndex = typename Layout::LongIndex;
948
+
949
+ using TensorRef = TensorRef<Element, Layout>;
950
+ using TensorView = TensorView<Element, Layout>;
951
+ using TensorCoord = typename Layout::TensorCoord;
952
+
953
+ using Pointer = Element *;
954
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
955
+
956
+ using UnderlyingIterator = EllPredicatedTileIterator<
957
+ layout::PitchLinearShape<Shape::kRow * kInterleavedK,
958
+ Shape::kColumn / kInterleavedK>,
959
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
960
+
961
+
962
+ using AccessType = typename UnderlyingIterator::AccessType;
963
+
964
+ /// Fragment object to be loaded or stored
965
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
966
+ ThreadMap::kElementsPerAccess>;
967
+
968
+ /// Predicate vector stores mask to guard accesses
969
+ using Mask = typename UnderlyingIterator::Mask;
970
+
971
+ /// Iterator for ELL storage
972
+ using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
973
+
974
+ /// Parameters object is precomputed state and is host-constructible
975
+ class Params {
976
+ private:
977
+ friend EllPredicatedTileIterator;
978
+
979
+ /// Parameters object
980
+ typename UnderlyingIterator::Params params_;
981
+
982
+ public:
983
+ CUTLASS_HOST_DEVICE
984
+ Params() {}
985
+
986
+ /// Construct the Params object given a pitch-linear tensor's layout
987
+ CUTLASS_HOST_DEVICE
988
+ Params(Layout const &layout)
989
+ : params_(layout::PitchLinear(layout.stride(0))) {}
990
+ };
991
+
992
+ private:
993
+ //
994
+ // Data members
995
+ //
996
+
997
+ /// Underlying pitch-linear tile iterator
998
+ UnderlyingIterator iterator_;
999
+
1000
+ public:
1001
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1002
+ /// and thread ID
1003
+ CUTLASS_HOST_DEVICE
1004
+ EllPredicatedTileIterator(
1005
+ /// Precomputed parameters object
1006
+ Params const &params,
1007
+ /// Pointer to start of tensor
1008
+ Pointer pointer,
1009
+ /// Extent of tensor
1010
+ TensorCoord extent,
1011
+ /// ID of each participating thread
1012
+ int thread_id,
1013
+ /// Initial offset of threadblock
1014
+ TensorCoord const &threadblock_offset)
1015
+ : iterator_(params.params_, pointer,
1016
+ layout::PitchLinearCoord(extent.row() * kInterleavedK,
1017
+ extent.column() / kInterleavedK),
1018
+ thread_id,
1019
+ layout::PitchLinearCoord(
1020
+ threadblock_offset.row() * kInterleavedK,
1021
+ threadblock_offset.column() / kInterleavedK)) {}
1022
+
1023
+ /// Construct a EllPredicatedTileIterator with zero threadblock offset
1024
+ CUTLASS_HOST_DEVICE
1025
+ EllPredicatedTileIterator(
1026
+ Params const &params, ///< Precomputed parameters object
1027
+ Pointer pointer, ///< Pointer to start of tensor
1028
+ TensorCoord extent, ///< Extent of tensor
1029
+ int thread_id ///< ID of each participating thread
1030
+ )
1031
+ : EllPredicatedTileIterator(params, pointer, extent, thread_id,
1032
+ make_Coord(0, 0)) {}
1033
+
1034
+ /// Adds a pointer offset in units of Element
1035
+ CUTLASS_HOST_DEVICE
1036
+ void add_pointer_offset(LongIndex pointer_offset) {
1037
+ iterator_.add_pointer_offset(pointer_offset);
1038
+ }
1039
+
1040
+ /// Advances to the next tile in memory.
1041
+ ///
1042
+ /// The first time this method is called, predicates are updated, and the
1043
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1044
+ /// Subsequent calls are lightweight and must only update the internal
1045
+ /// pointer.
1046
+ CUTLASS_HOST_DEVICE
1047
+ EllPredicatedTileIterator &operator++() {
1048
+ ++iterator_;
1049
+ return *this;
1050
+ }
1051
+
1052
+ /// Advances to the next tile in memory.
1053
+ ///
1054
+ /// The first time this method is called, predicates are updated, and the
1055
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1056
+ /// Subsequent calls are lightweight and must only update the internal
1057
+ /// pointer.
1058
+ CUTLASS_HOST_DEVICE
1059
+ EllPredicatedTileIterator operator++(int) {
1060
+ EllPredicatedTileIterator self(*this);
1061
+ operator++();
1062
+ return self;
1063
+ }
1064
+
1065
+ /// Returns a stride
1066
+ CUTLASS_HOST_DEVICE
1067
+ int get_stride() const { return iterator_.get_stride(); }
1068
+
1069
+ /// Clears the predicate set efficiently
1070
+ CUTLASS_HOST_DEVICE
1071
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1072
+
1073
+ /// Clears the predicate set efficiently
1074
+ CUTLASS_HOST_DEVICE
1075
+ void enable_mask() { iterator_.enable_mask(); }
1076
+
1077
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1078
+ CUTLASS_HOST_DEVICE
1079
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1080
+
1081
+ /// Gets the mask
1082
+ CUTLASS_HOST_DEVICE
1083
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1084
+
1085
+ /// add mask for small tiles in ELL
1086
+ CUTLASS_HOST_DEVICE
1087
+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); }
1088
+
1089
+ /// Loads a fragment from memory
1090
+ CUTLASS_DEVICE
1091
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1092
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1093
+ }
1094
+
1095
+ CUTLASS_DEVICE
1096
+ void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
1097
+ iterator_.load_with_ell_index(frag, ell_iter);
1098
+ }
1099
+
1100
+ CUTLASS_DEVICE
1101
+ void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
1102
+ iterator_.load_with_ell_index_fast(frag, ell_iter);
1103
+ }
1104
+
1105
+ /// Loads a fragment from memory
1106
+ CUTLASS_DEVICE
1107
+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
1108
+
1109
+ /// Store a fragment to memory
1110
+ CUTLASS_DEVICE
1111
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1112
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1113
+ }
1114
+
1115
+ /// Store a fragment to memory
1116
+ CUTLASS_DEVICE
1117
+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
1118
+ };
1119
+
1120
+ ////////////////////////////////////////////////////////////////////////////////
1121
+
1122
+ /// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is
1123
+ /// mapped to the congruous layout.
1124
+ ///
1125
+ /// Satisfies: ForwardTileIteratorConcept |
1126
+ /// ReadableContiguousTileIteratorConcept |
1127
+ /// WriteableContiguousTileIteratorConcept |
1128
+ /// MaskedTileIteratorConcept
1129
+ ///
1130
+ template <typename Shape_, typename Element_, int AdvanceRank,
1131
+ typename ThreadMap_, int AccessSize, int InterleavedK>
1132
+ class EllPredicatedTileIterator<Shape_, Element_,
1133
+ layout::RowMajorInterleaved<InterleavedK>,
1134
+ AdvanceRank, ThreadMap_, AccessSize> {
1135
+ public:
1136
+ static_assert(
1137
+ AdvanceRank == 0 || AdvanceRank == 1,
1138
+ "Specialization for pitch-linear iterator may along advance along the "
1139
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1140
+
1141
+ using Shape = Shape_;
1142
+ using Element = Element_;
1143
+ static int const kInterleavedK = InterleavedK;
1144
+ using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1145
+ static int const kAdvanceRank = AdvanceRank;
1146
+ using ThreadMap = ThreadMap_;
1147
+
1148
+ using Index = typename Layout::Index;
1149
+ using LongIndex = typename Layout::LongIndex;
1150
+
1151
+ using TensorRef = TensorRef<Element, Layout>;
1152
+ using TensorView = TensorView<Element, Layout>;
1153
+ using TensorCoord = typename Layout::TensorCoord;
1154
+
1155
+ using Pointer = Element *;
1156
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1157
+
1158
+ using UnderlyingIterator = EllPredicatedTileIterator<
1159
+ layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1160
+ Shape::kRow / kInterleavedK>,
1161
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
1162
+
1163
+
1164
+ using AccessType = typename UnderlyingIterator::AccessType;
1165
+
1166
+ /// Fragment object to be loaded or stored
1167
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
1168
+ ThreadMap::kElementsPerAccess>;
1169
+
1170
+ /// Predicate vector stores mask to guard accesses
1171
+ using Mask = typename UnderlyingIterator::Mask;
1172
+
1173
+ /// Parameters object is precomputed state and is host-constructible
1174
+ class Params {
1175
+ private:
1176
+ friend EllPredicatedTileIterator;
1177
+
1178
+ /// Parameters object
1179
+ typename UnderlyingIterator::Params params_;
1180
+
1181
+ public:
1182
+ CUTLASS_HOST_DEVICE
1183
+ Params() {}
1184
+
1185
+ /// Construct the Params object given a pitch-linear tensor's layout
1186
+ CUTLASS_HOST_DEVICE
1187
+ Params(Layout const &layout)
1188
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1189
+ };
1190
+
1191
+ private:
1192
+ //
1193
+ // Data members
1194
+ //
1195
+
1196
+ /// Underlying pitch-linear tile iterator
1197
+ UnderlyingIterator iterator_;
1198
+
1199
+ public:
1200
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1201
+ /// and thread ID
1202
+ CUTLASS_HOST_DEVICE
1203
+ EllPredicatedTileIterator(
1204
+ /// Precomputed parameters object
1205
+ Params const &params,
1206
+ /// Pointer to start of tensor
1207
+ Pointer pointer,
1208
+ /// Extent of tensor
1209
+ TensorCoord extent,
1210
+ /// ID of each participating thread
1211
+ int thread_id,
1212
+ /// Initial offset of threadblock
1213
+ TensorCoord const &threadblock_offset)
1214
+ : iterator_(params.params_, pointer,
1215
+ layout::PitchLinearCoord(extent.column() * kInterleavedK,
1216
+ extent.row() / kInterleavedK),
1217
+ thread_id,
1218
+ layout::PitchLinearCoord(
1219
+ threadblock_offset.column() * kInterleavedK,
1220
+ threadblock_offset.row() / kInterleavedK)) {}
1221
+
1222
+ /// Construct a EllPredicatedTileIterator with zero threadblock offset
1223
+ CUTLASS_HOST_DEVICE
1224
+ EllPredicatedTileIterator(
1225
+ Params const &params, ///< Precomputed parameters object
1226
+ Pointer pointer, ///< Pointer to start of tensor
1227
+ TensorCoord extent, ///< Extent of tensor
1228
+ int thread_id ///< ID of each participating thread
1229
+ )
1230
+ : EllPredicatedTileIterator(params, pointer, extent, thread_id,
1231
+ make_Coord(0, 0)) {}
1232
+
1233
+ /// Adds a pointer offset in units of Element
1234
+ CUTLASS_HOST_DEVICE
1235
+ void add_pointer_offset(LongIndex pointer_offset) {
1236
+ iterator_.add_pointer_offset(pointer_offset);
1237
+ }
1238
+
1239
+ /// Advances to the next tile in memory.
1240
+ ///
1241
+ /// The first time this method is called, predicates are updated, and the
1242
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1243
+ /// Subsequent calls are lightweight and must only update the internal
1244
+ /// pointer.
1245
+ CUTLASS_HOST_DEVICE
1246
+ EllPredicatedTileIterator &operator++() {
1247
+ ++iterator_;
1248
+ return *this;
1249
+ }
1250
+
1251
+ /// Advances to the next tile in memory.
1252
+ ///
1253
+ /// The first time this method is called, predicates are updated, and the
1254
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1255
+ /// Subsequent calls are lightweight and must only update the internal
1256
+ /// pointer.
1257
+ CUTLASS_HOST_DEVICE
1258
+ EllPredicatedTileIterator operator++(int) {
1259
+ EllPredicatedTileIterator self(*this);
1260
+ operator++();
1261
+ return self;
1262
+ }
1263
+
1264
+ /// Returns a stride
1265
+ CUTLASS_HOST_DEVICE
1266
+ int get_stride() const { return iterator_.get_stride(); }
1267
+
1268
+ /// Clears the predicate set efficiently
1269
+ CUTLASS_HOST_DEVICE
1270
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1271
+
1272
+ /// Clears the predicate set efficiently
1273
+ CUTLASS_HOST_DEVICE
1274
+ void enable_mask() { iterator_.enable_mask(); }
1275
+
1276
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1277
+ CUTLASS_HOST_DEVICE
1278
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1279
+
1280
+ /// Gets the mask
1281
+ CUTLASS_HOST_DEVICE
1282
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1283
+
1284
+ /// add mask for small tiles in ELL
1285
+ CUTLASS_HOST_DEVICE
1286
+ void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); }
1287
+
1288
+ /// Loads a fragment from memory
1289
+ CUTLASS_DEVICE
1290
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1291
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1292
+ }
1293
+
1294
+ /// Loads a fragment from memory
1295
+ CUTLASS_DEVICE
1296
+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
1297
+
1298
+ /// Store a fragment to memory
1299
+ CUTLASS_DEVICE
1300
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1301
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1302
+ }
1303
+
1304
+ /// Store a fragment to memory
1305
+ CUTLASS_DEVICE
1306
+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
1307
+ };
1308
+
1309
+ ////////////////////////////////////////////////////////////////////////////////
1310
+
1311
+ } // namespace threadblock
1312
+ } // namespace transform
1313
+ } // namespace cutlass
1314
+
1315
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Templates calculating the address and predicates to the load of scale and bias vectors.
34
+
35
+ This iterator uses masks to guard out-of-bounds accesses.
36
+
37
+ It can be used to load the gamma and beta vectors of layernorm which is loop variant.
38
+
39
+ A precomputed "Params" object minimizes the amount of state that must be
40
+ stored in registers, and integer addition is used to advance the pointer
41
+ through memory.
42
+ */
43
+
44
+ #pragma once
45
+
46
+ #include "cutlass/array.h"
47
+ #include "cutlass/coord.h"
48
+ #include "cutlass/cutlass.h"
49
+ #include "cutlass/layout/matrix.h"
50
+ #include "cutlass/layout/pitch_linear.h"
51
+ #include "cutlass/matrix_shape.h"
52
+ #include "cutlass/predicate_vector.h"
53
+ #include "cutlass/tensor_ref.h"
54
+ #include "cutlass/tensor_view.h"
55
+ #include "cutlass/conv/threadblock/conv2d_params.h"
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////
58
+
59
+ namespace cutlass {
60
+ namespace transform {
61
+ namespace threadblock {
62
+
63
+ ////////////////////////////////////////////////////////////////////////////////
64
+
65
+ /// PredicatedScaleBiasVectorAccessIterator
66
+ ///
67
+ template <typename ThreadblockShape,
68
+ typename Element,
69
+ typename Layout>
70
+ class PredicatedScaleBiasVectorAccessIterator;
71
+
72
+ ////////////////////////////////////////////////////////////////////////////////
73
+
74
+ /// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data.
75
+ ///
76
+ template <typename ThreadblockShape_, typename Element_>
77
+ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
78
+ Element_,
79
+ layout::PitchLinear> {
80
+ public:
81
+
82
+ using ThreadblockShape = ThreadblockShape_;
83
+ using Element = Element_;
84
+ using Layout = layout::PitchLinear;
85
+
86
+ using Index = typename Layout::Index;
87
+ using LongIndex = typename Layout::LongIndex;
88
+
89
+ using TensorRef = TensorRef<Element, Layout>;
90
+ using TensorView = TensorView<Element, Layout>;
91
+ using TensorCoord = typename Layout::TensorCoord;
92
+
93
+ using ConstPointer = const Element *;
94
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
95
+
96
+ static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
97
+ static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess;
98
+
99
+ using AccessType = AlignedArray<Element, kElementsPerAccess>;
100
+
101
+ private:
102
+ /// Internal pointer type permits fast address arithmetic
103
+ using BytePointer = char *;
104
+
105
+ private:
106
+ //
107
+ // Data members
108
+ //
109
+
110
+ /// Internal pointer to first access of tile
111
+ BytePointer pointer_;
112
+
113
+ TensorCoord thread_offset_;
114
+
115
+ int problem_size_k_;
116
+
117
+ /// Used for out-of-order visitation
118
+ bool is_residue_tile_;
119
+
120
+ bool guard_;
121
+
122
+ TensorCoord::Index residue_size_;
123
+
124
+ public:
125
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
126
+ /// and thread ID
127
+ CUTLASS_HOST_DEVICE
128
+ PredicatedScaleBiasVectorAccessIterator(
129
+ /// Extent of tensor
130
+ int problem_size_k,
131
+ /// Pointer to the start of the scale vector
132
+ ConstPointer scale_pointer,
133
+ /// Pointer to the start of the bias vector
134
+ ConstPointer bias_pointer,
135
+ /// ID of each participating thread
136
+ int thread_id,
137
+ /// Initial offset of threadblock
138
+ TensorCoord const &threadblock_offset) {
139
+ pointer_ = (thread_id < kThreads)
140
+ ? reinterpret_cast<BytePointer>(
141
+ const_cast<NonConstPointer>(scale_pointer))
142
+ : reinterpret_cast<BytePointer>(
143
+ const_cast<NonConstPointer>(bias_pointer));
144
+
145
+ // Per-thread offset in logical coordinates of tensor
146
+ int thread_base = (thread_id < kThreads) ? 0 : kThreads;
147
+
148
+ problem_size_k_ = problem_size_k;
149
+
150
+ is_residue_tile_ = true;
151
+
152
+ residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous;
153
+
154
+ if (residue_size_ == 0) {
155
+ residue_size_ = ThreadblockShape::kContiguous;
156
+ }
157
+
158
+ guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_;
159
+
160
+ thread_offset_ =
161
+ threadblock_offset +
162
+ TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0);
163
+
164
+ set_iteration_index(0);
165
+ }
166
+
167
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
168
+ CUTLASS_HOST_DEVICE
169
+ PredicatedScaleBiasVectorAccessIterator(
170
+ /// Extent of tensor
171
+ int problem_size_k,
172
+ /// Pointer to start of scale vector
173
+ ConstPointer scale_pointer,
174
+ /// Pointer to start of scale vector
175
+ ConstPointer bias_pointer,
176
+ ///< ID of each participating thread
177
+ int thread_id)
178
+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k,
179
+ scale_pointer, bias_pointer,
180
+ thread_id, make_Coord(0, 0)) {}
181
+
182
+ /// Overrides the internal iteration index
183
+ CUTLASS_HOST_DEVICE
184
+ void set_iteration_index(int index) {}
185
+
186
+ /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles
187
+ CUTLASS_DEVICE
188
+ void add_tile_offset(
189
+ TensorCoord const &tile_offset) {
190
+
191
+ guard_ = threadIdx.x < kThreads * 2;
192
+
193
+ TensorCoord offset = is_residue_tile_ ?
194
+ TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0)
195
+ : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0);
196
+
197
+ thread_offset_ =
198
+ thread_offset_ +
199
+ offset;
200
+
201
+ is_residue_tile_ = false;
202
+ }
203
+
204
+ /// Returns a pointer
205
+ CUTLASS_HOST_DEVICE
206
+ AccessType *get() const {
207
+
208
+ return reinterpret_cast<AccessType *>(
209
+ pointer_ +
210
+ (thread_offset_.contiguous() * sizeof_bits<Element>::value / 8));
211
+ }
212
+
213
+ /// Increment and return an instance to self.
214
+ CUTLASS_HOST_DEVICE
215
+ PredicatedScaleBiasVectorAccessIterator &operator++() {
216
+ return *this;
217
+ }
218
+
219
+ /// Increment and return an instance to self.
220
+ CUTLASS_DEVICE
221
+ PredicatedScaleBiasVectorAccessIterator operator++(int) {
222
+ PredicatedScaleBiasVectorAccessIterator self(*this);
223
+ operator++();
224
+ return self;
225
+ }
226
+
227
+ /// Clears the predicate set efficiently
228
+ CUTLASS_HOST_DEVICE
229
+ void clear_mask(bool enable = true) {
230
+ guard_ &= (!enable);
231
+ }
232
+
233
+ /// Returns whether access is valid or not
234
+ CUTLASS_HOST_DEVICE
235
+ bool valid() {
236
+ return guard_;
237
+ }
238
+ };
239
+
240
+ ////////////////////////////////////////////////////////////////////////////////
241
+
242
+ /// Specialization of PredicatedTileAccessIterator for row-major data.
243
+ ///
244
+ /// Satisfies: ForwardTileIteratorConcept |
245
+ /// ReadableContiguousTileIteratorConcept |
246
+ /// WriteableContiguousTileIteratorConcept |
247
+ /// MaskedTileIteratorConcept
248
+ ///
249
+ template <typename ThreadblockShape_,
250
+ typename Element_>
251
+ class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
252
+ Element_,
253
+ layout::RowMajor> {
254
+ public:
255
+
256
+ using ThreadblockShape = ThreadblockShape_;
257
+ using Element = Element_;
258
+ using Layout = layout::RowMajor;
259
+
260
+ using Index = typename Layout::Index;
261
+ using LongIndex = typename Layout::LongIndex;
262
+
263
+ using TensorRef = TensorRef<Element, Layout>;
264
+ using TensorView = TensorView<Element, Layout>;
265
+ using TensorCoord = typename Layout::TensorCoord;
266
+
267
+ using ConstPointer = const Element *;
268
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
269
+
270
+ using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator<
271
+ layout::PitchLinearShape<ThreadblockShape::kColumn, ThreadblockShape::kRow>,
272
+ Element,
273
+ layout::PitchLinear>;
274
+
275
+ using AccessType = typename UnderlyingIterator::AccessType;
276
+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
277
+
278
+ private:
279
+ //
280
+ // Data members
281
+ //
282
+
283
+ /// Underlying pitch-linear tile iterator
284
+ UnderlyingIterator iterator_;
285
+
286
+ public:
287
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
288
+ /// and thread ID
289
+ CUTLASS_HOST_DEVICE
290
+ PredicatedScaleBiasVectorAccessIterator(
291
+ ///< Extent of tensor
292
+ int problem_size_k,
293
+ ///< Pointer to the start of the scale vector
294
+ ConstPointer scale_pointer,
295
+ ///< Pointer to the start of the bias vector
296
+ ConstPointer bias_pointer,
297
+ ///< ID of each participating thread
298
+ int thread_id,
299
+ ///< Initial offset of threadblock
300
+ TensorCoord const &threadblock_offset)
301
+ : iterator_(problem_size_k, scale_pointer, bias_pointer,
302
+ thread_id,
303
+ layout::PitchLinearCoord(threadblock_offset.column(),
304
+ threadblock_offset.row())) {}
305
+
306
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
307
+ CUTLASS_HOST_DEVICE
308
+ PredicatedScaleBiasVectorAccessIterator(
309
+ int problem_size_k, ///< Extent of tensor
310
+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
311
+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
312
+ int thread_id ///< ID of each participating thread
313
+ )
314
+ : PredicatedScaleBiasVectorAccessIterator(problem_size_k,
315
+ scale_pointer, bias_pointer,
316
+ thread_id, make_Coord(0, 0)) {}
317
+
318
+ /// Advances an iterator along logical dimensions of matrix in units of whole
319
+ /// threadblock tiles
320
+ CUTLASS_HOST_DEVICE
321
+ void add_tile_offset(TensorCoord const &tile_offset) {
322
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
323
+ }
324
+
325
+ /// Returns a pointer
326
+ CUTLASS_HOST_DEVICE
327
+ AccessType *get() const {
328
+ return reinterpret_cast<AccessType *>(iterator_.get());
329
+ }
330
+
331
+ /// Advances to the next tile in memory.
332
+ ///
333
+ /// The first time this method is called, predicates are updated, and the
334
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
335
+ /// Subsequent calls are lightweight and must only update the internal
336
+ /// pointer.
337
+ CUTLASS_HOST_DEVICE
338
+ PredicatedScaleBiasVectorAccessIterator &operator++() {
339
+ ++iterator_;
340
+ return *this;
341
+ }
342
+
343
+ /// Advances to the next tile in memory.
344
+ ///
345
+ /// The first time this method is called, predicates are updated, and the
346
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
347
+ /// Subsequent calls are lightweight and must only update the internal
348
+ /// pointer.
349
+ CUTLASS_HOST_DEVICE
350
+ PredicatedScaleBiasVectorAccessIterator operator++(int) {
351
+ PredicatedScaleBiasVectorAccessIterator self(*this);
352
+ operator++();
353
+ return self;
354
+ }
355
+
356
+ /// Clears the predicate set efficiently
357
+ CUTLASS_HOST_DEVICE
358
+ void clear_mask(bool enable = true) {
359
+ iterator_.clear_mask(enable);
360
+ }
361
+
362
+ /// Returns whether access is valid or not
363
+ CUTLASS_HOST_DEVICE
364
+ bool valid() {
365
+ return iterator_.valid();
366
+ }
367
+ };
368
+
369
+ ////////////////////////////////////////////////////////////////////////////////
370
+
371
+ } // namespace threadblock
372
+ } // namespace transform
373
+ } // namespace cutlass
374
+
375
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Templates calculating the address and predicates to the load of scale and bias vectors.
34
+
35
+ This iterator uses masks to guard out-of-bounds accesses.
36
+
37
+ This can be used to load var and mean vectors in layernorm which is loop invariant.
38
+
39
+ A precomputed "Params" object minimizes the amount of state that must be
40
+ stored in registers, and integer addition is used to advance the pointer
41
+ through memory.
42
+ */
43
+
44
+ #pragma once
45
+
46
+ #include "cutlass/array.h"
47
+ #include "cutlass/coord.h"
48
+ #include "cutlass/cutlass.h"
49
+ #include "cutlass/layout/matrix.h"
50
+ #include "cutlass/layout/pitch_linear.h"
51
+ #include "cutlass/matrix_shape.h"
52
+ #include "cutlass/predicate_vector.h"
53
+ #include "cutlass/tensor_ref.h"
54
+ #include "cutlass/tensor_view.h"
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////
57
+
58
+ namespace cutlass {
59
+ namespace transform {
60
+ namespace threadblock {
61
+
62
+ ////////////////////////////////////////////////////////////////////////////////
63
+
64
+ /// PredicatedScaleBiasVectorIterator
65
+ ///
66
+ template <typename WarpShape,
67
+ typename Element,
68
+ typename Layout>
69
+ class PredicatedScaleBiasVectorIterator;
70
+
71
+ ////////////////////////////////////////////////////////////////////////////////
72
+
73
+ /// Specialization of PredicatedTileIterator for wgrad pitch-linear data.
74
+ ///
75
+ template <typename WarpShape_, typename Element_>
76
+ class PredicatedScaleBiasVectorIterator<WarpShape_,
77
+ Element_,
78
+ layout::PitchLinear> {
79
+ public:
80
+
81
+ using WarpShape = WarpShape_;
82
+ using Element = Element_;
83
+ using Layout = layout::PitchLinear;
84
+
85
+ using Index = typename Layout::Index;
86
+ using LongIndex = typename Layout::LongIndex;
87
+
88
+ using TensorRef = TensorRef<Element, Layout>;
89
+ using TensorView = TensorView<Element, Layout>;
90
+ using TensorCoord = typename Layout::TensorCoord;
91
+
92
+ using ConstPointer = const Element *;
93
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
94
+
95
+ static int const kElementsPerAccess = 1;
96
+
97
+ using AccessType = AlignedArray<Element, kElementsPerAccess>;
98
+
99
+ static int const kIterations = WarpShape::kContiguous / 8;
100
+
101
+ /// Fragment object to be loaded or stored
102
+ using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>;
103
+
104
+ private:
105
+ //
106
+ // Data members
107
+ //
108
+
109
+ /// Internal pointer to first access of tile
110
+ ConstPointer scale_pointer_;
111
+ ConstPointer bias_pointer_;
112
+
113
+ /// Size of tensor
114
+ int problem_size_;
115
+
116
+ int32_t thread_offset_;
117
+
118
+ public:
119
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
120
+ /// and thread ID
121
+ CUTLASS_HOST_DEVICE
122
+ PredicatedScaleBiasVectorIterator(
123
+ /// Extent of tensor
124
+ int problem_size,
125
+ /// Pointer to the start of the scale vector
126
+ ConstPointer scale_pointer,
127
+ /// Pointer to the start of the bias vector
128
+ ConstPointer bias_pointer,
129
+ /// ID of each participating thread
130
+ int thread_id,
131
+ /// Initial offset of threadblock
132
+ TensorCoord const &threadblock_offset)
133
+ : problem_size_(problem_size),
134
+ scale_pointer_(scale_pointer),
135
+ bias_pointer_(bias_pointer) {
136
+
137
+ thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4;
138
+ }
139
+
140
+ /// Construct a PredicatedTileIterator with zero threadblock offset
141
+ CUTLASS_HOST_DEVICE
142
+ PredicatedScaleBiasVectorIterator(
143
+ /// Extent of tensor
144
+ int problem_size,
145
+ /// Pointer to start of scale vector
146
+ ConstPointer scale_pointer,
147
+ /// Pointer to start of scale vector
148
+ ConstPointer bias_pointer,
149
+ ///< ID of each participating thread
150
+ int thread_id)
151
+ : PredicatedScaleBiasVectorIterator(problem_size,
152
+ scale_pointer, bias_pointer,
153
+ thread_id, make_Coord(0, 0)) {}
154
+
155
+ /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles
156
+ CUTLASS_DEVICE
157
+ void add_tile_offset(
158
+ TensorCoord const &tile_offset) {
159
+
160
+ thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous());
161
+ }
162
+
163
+ /// Loads a fragment from memory
164
+ CUTLASS_DEVICE
165
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
166
+
167
+ frag.fill(__float2half2_rn(0.0f));
168
+ __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag);
169
+
170
+ // load scale
171
+ CUTLASS_PRAGMA_UNROLL
172
+ for (int c = 0; c < kIterations; ++c) {
173
+
174
+ cutlass::arch::global_load<
175
+ __half,
176
+ sizeof(AccessType)
177
+ >(
178
+ frag_ptr[c * 2].x,
179
+ scale_pointer_ + thread_offset_ + c * 8,
180
+ (thread_offset_ + c * 8) < problem_size_
181
+ );
182
+ }
183
+
184
+ // load bias
185
+ CUTLASS_PRAGMA_UNROLL
186
+ for (int c = 0; c < kIterations; ++c) {
187
+
188
+ cutlass::arch::global_load<
189
+ __half,
190
+ sizeof(AccessType)
191
+ >(
192
+ frag_ptr[c * 2 + 1].x,
193
+ bias_pointer_ + thread_offset_ + c * 8,
194
+ (thread_offset_ + c * 8) < problem_size_
195
+ );
196
+ }
197
+
198
+ // duplicate scale
199
+ CUTLASS_PRAGMA_UNROLL
200
+ for (int c = 0; c < kIterations; ++c) {
201
+ frag_ptr[c * 2].y = frag_ptr[c * 2].x;
202
+ }
203
+
204
+ // duplicate bias
205
+ CUTLASS_PRAGMA_UNROLL
206
+ for (int c = 0; c < kIterations; ++c) {
207
+ frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x;
208
+ }
209
+ }
210
+
211
+ /// Loads a fragment from memory
212
+ CUTLASS_DEVICE
213
+ void load(Fragment &frag) {
214
+ load_with_pointer_offset(frag, 0);
215
+ }
216
+ };
217
+
218
+ ////////////////////////////////////////////////////////////////////////////////
219
+
220
+ /// Specialization of PredicatedTileIterator for row-major data.
221
+ ///
222
+ /// Satisfies: ForwardTileIteratorConcept |
223
+ /// ReadableContiguousTileIteratorConcept |
224
+ /// WriteableContiguousTileIteratorConcept |
225
+ /// MaskedTileIteratorConcept
226
+ ///
227
+ template <typename WarpShape_,
228
+ typename Element_>
229
+ class PredicatedScaleBiasVectorIterator<WarpShape_,
230
+ Element_,
231
+ layout::RowMajor> {
232
+ public:
233
+
234
+ using WarpShape = WarpShape_;
235
+ using Element = Element_;
236
+ using Layout = layout::RowMajor;
237
+
238
+ using Index = typename Layout::Index;
239
+ using LongIndex = typename Layout::LongIndex;
240
+
241
+ using TensorRef = TensorRef<Element, Layout>;
242
+ using TensorView = TensorView<Element, Layout>;
243
+ using TensorCoord = typename Layout::TensorCoord;
244
+
245
+ using ConstPointer = const Element *;
246
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
247
+
248
+ using UnderlyingIterator = PredicatedScaleBiasVectorIterator<
249
+ layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
250
+ Element,
251
+ layout::PitchLinear>;
252
+
253
+ using AccessType = typename UnderlyingIterator::AccessType;
254
+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
255
+ using Fragment = typename UnderlyingIterator::Fragment;
256
+
257
+
258
+ private:
259
+ //
260
+ // Data members
261
+ //
262
+
263
+ /// Underlying pitch-linear tile iterator
264
+ UnderlyingIterator iterator_;
265
+
266
+ public:
267
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
268
+ /// and thread ID
269
+ CUTLASS_HOST_DEVICE
270
+ PredicatedScaleBiasVectorIterator(
271
+ ///< Extent of tensor
272
+ int problem_size,
273
+ ///< Pointer to the start of the scale vector
274
+ ConstPointer scale_pointer,
275
+ ///< Pointer to the start of the bias vector
276
+ ConstPointer bias_pointer,
277
+ ///< ID of each participating thread
278
+ int thread_id,
279
+ ///< Initial offset of threadblock
280
+ TensorCoord const &threadblock_offset)
281
+ : iterator_(problem_size, scale_pointer, bias_pointer,
282
+ thread_id,
283
+ layout::PitchLinearCoord(threadblock_offset.column(),
284
+ threadblock_offset.row())) {}
285
+
286
+ /// Construct a PredicatedTileIterator with zero threadblock offset
287
+ CUTLASS_HOST_DEVICE
288
+ PredicatedScaleBiasVectorIterator(
289
+ int problem_size, ///< Extent of tensor
290
+ ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
291
+ ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
292
+ int thread_id ///< ID of each participating thread
293
+ )
294
+ : PredicatedScaleBiasVectorIterator(problem_size,
295
+ scale_pointer, bias_pointer,
296
+ thread_id, make_Coord(0, 0)) {}
297
+
298
+ /// Overrides the internal iteration index
299
+ CUTLASS_HOST_DEVICE
300
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
301
+
302
+ /// Advances an iterator along logical dimensions of matrix in units of whole
303
+ /// threadblock tiles
304
+ CUTLASS_HOST_DEVICE
305
+ void add_tile_offset(TensorCoord const &tile_offset) {
306
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
307
+ }
308
+
309
+ /// Loads a fragment from memory
310
+ CUTLASS_DEVICE
311
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
312
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
313
+ }
314
+
315
+ /// Loads a fragment from memory
316
+ CUTLASS_DEVICE
317
+ void load(Fragment &frag) {
318
+ iterator_.load(frag);
319
+ }
320
+ };
321
+
322
+ ////////////////////////////////////////////////////////////////////////////////
323
+
324
+ } // namespace threadblock
325
+ } // namespace transform
326
+ } // namespace cutlass
327
+
328
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h ADDED
@@ -0,0 +1,2118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates calculating the address and predicates to the load of tiles
33
+ from pitch-linear rank=2 tensors.
34
+
35
+ This iterator uses masks to guard out-of-bounds accesses. The first tile this
36
+ iterator visits maybe partial, then the remaining tiles are complete. So, we
37
+ only need to compute the predicates twice, once before the first tile and
38
+ once for the remaining full tiles which can share the same predicates.
39
+
40
+ A precomputed "Params" object minimizes the amount of state that must be
41
+ stored in registers, and integer addition is used to advance the pointer
42
+ through memory.
43
+ */
44
+
45
+ #pragma once
46
+
47
+ #include "cutlass/array.h"
48
+ #include "cutlass/coord.h"
49
+ #include "cutlass/cutlass.h"
50
+ #include "cutlass/layout/matrix.h"
51
+ #include "cutlass/layout/permute.h"
52
+ #include "cutlass/layout/pitch_linear.h"
53
+ #include "cutlass/matrix_shape.h"
54
+ #include "cutlass/predicate_vector.h"
55
+ #include "cutlass/tensor_ref.h"
56
+ #include "cutlass/tensor_view.h"
57
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////
60
+
61
+ ////////////////////////////////////////////////////////////////////////////////
62
+
63
+ namespace cutlass {
64
+ namespace transform {
65
+ namespace threadblock {
66
+
67
+ ////////////////////////////////////////////////////////////////////////////////
68
+
69
+ /// PredicatedTileAccessIteratorPredicates
70
+ ///
71
+ template <typename Shape_, typename Element_, typename Layout_, int AdvanceRank,
72
+ typename ThreadMap_, typename AccessType_>
73
+ class PredicatedTileAccessIteratorPredicates {
74
+ public:
75
+ using Shape = Shape_;
76
+ using Element = Element_;
77
+ using Layout = Layout_;
78
+ static int const kAdvanceRank = AdvanceRank;
79
+ using ThreadMap = ThreadMap_;
80
+ using AccessType = AccessType_;
81
+
82
+ using Index = typename Layout::Index;
83
+ using LongIndex = typename Layout::LongIndex;
84
+
85
+ using TensorCoord = typename Layout::TensorCoord;
86
+
87
+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
88
+
89
+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
90
+ "Vectors implied by the thread map must be divisible by the access type.");
91
+
92
+ static int const kPredicatesPerByte = 4;
93
+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
94
+
95
+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
96
+
97
+ /// Number of 32b words containing predicates
98
+ static int const kPredicateByteCount =
99
+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
100
+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
101
+
102
+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
103
+
104
+ static_assert(kPredicateWordCount <= 4, "Too many predicates.");
105
+
106
+ /// Predicate vector stores mask to guard accesses
107
+ using Mask = Array<uint32_t, kPredicateWordCount>;
108
+
109
+ // private:
110
+ /// Guard predicates
111
+ uint32_t predicates_[kPredicateWordCount];
112
+
113
+ /// Size of tensor
114
+ TensorCoord extent_;
115
+
116
+ /// Initial offset for each thread
117
+ TensorCoord thread_offset_;
118
+
119
+ /// Offset to the first steady-state tile
120
+ TensorCoord residue_offset_;
121
+
122
+ /// Iteration along vectors implied by the thread map
123
+ int iteration_vector_;
124
+
125
+ /// Iteration in the contiguous dimension
126
+ int iteration_contiguous_;
127
+
128
+ /// Iteration in the strided dimension
129
+ int iteration_strided_;
130
+
131
+ public:
132
+ /// Computes predicates based on internally tracked per-thread offset.
133
+ CUTLASS_DEVICE
134
+ void compute_predicates_(
135
+ /// Extent of the matrix window
136
+ TensorCoord extent,
137
+ /// optionally, simplify predicate calculation during 'steady state' phase
138
+ bool is_steady_state = false) {
139
+
140
+ CUTLASS_PRAGMA_UNROLL
141
+ for (int i = 0; i < kPredicateWordCount; ++i) {
142
+ predicates_[i] = 0u;
143
+ }
144
+
145
+ CUTLASS_PRAGMA_UNROLL
146
+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
147
+
148
+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
149
+
150
+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
151
+
152
+ int c = access_residual / kAccessesPerVector;
153
+ int v = access_residual % kAccessesPerVector;
154
+
155
+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
156
+ s * ThreadMap::Delta::kStrided);
157
+
158
+ TensorCoord coord = thread_offset_ + iteration_coord;
159
+
160
+ bool guard;
161
+
162
+ if (is_steady_state) {
163
+ if (kAdvanceRank == 0) {
164
+ guard = (coord.strided() < extent.strided());
165
+ } else {
166
+ guard = (coord.contiguous() < extent.contiguous());
167
+ }
168
+ } else {
169
+ guard = (coord.strided() < extent.strided() &&
170
+ coord.contiguous() < extent.contiguous());
171
+ }
172
+
173
+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
174
+
175
+ int word_idx = pred_idx / kPredicatesPerWord;
176
+ int residual = pred_idx % kPredicatesPerWord;
177
+ int byte_idx = residual / kPredicatesPerByte;
178
+ int bit_idx = residual % kPredicatesPerByte;
179
+
180
+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
181
+
182
+ }
183
+
184
+ }
185
+
186
+ CUTLASS_HOST_DEVICE
187
+ void set_predicates(int thread_id, TensorCoord const &threadblock_offset) {
188
+
189
+ TensorCoord residue_extent;
190
+ if (kAdvanceRank) {
191
+
192
+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
193
+ if (!residue_size) {
194
+ residue_size = Shape::kStrided;
195
+ }
196
+
197
+ residue_offset_ = make_Coord(0, residue_size);
198
+ residue_extent = make_Coord(
199
+ extent_.contiguous(),
200
+ min(threadblock_offset.strided() + residue_size, extent_.strided())
201
+ );
202
+ } else {
203
+
204
+ typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
205
+ if (!residue_size) {
206
+ residue_size = Shape::kContiguous;
207
+ }
208
+
209
+ residue_offset_ = make_Coord(residue_size, 0);
210
+
211
+ residue_extent = make_Coord(
212
+ min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
213
+ extent_.strided()
214
+ );
215
+ }
216
+
217
+ // Per-thread offset in logical coordinates of tensor
218
+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
219
+
220
+ compute_predicates_(residue_extent, false);
221
+
222
+ set_iteration_index(0);
223
+ }
224
+
225
+ /// Default constructor
226
+ PredicatedTileAccessIteratorPredicates() = default;
227
+
228
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
229
+ /// and thread ID
230
+ CUTLASS_HOST_DEVICE
231
+ PredicatedTileAccessIteratorPredicates(
232
+ /// Extent of tensor
233
+ TensorCoord extent)
234
+ : extent_(extent) {
235
+ }
236
+
237
+ /// Overrides the internal iteration index
238
+ CUTLASS_HOST_DEVICE
239
+ void set_iteration_index(int index) {
240
+
241
+ iteration_vector_ = index % kAccessesPerVector;
242
+ int residual_access = index / kAccessesPerVector;
243
+
244
+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
245
+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
246
+
247
+ }
248
+
249
+ /// Increment and return an instance to self.
250
+ CUTLASS_HOST_DEVICE
251
+ PredicatedTileAccessIteratorPredicates &operator++() {
252
+
253
+ return *this;
254
+ }
255
+
256
+ /// Clears the predicate set efficiently
257
+ CUTLASS_HOST_DEVICE
258
+ void clear_mask(bool enable = true) {
259
+ CUTLASS_PRAGMA_UNROLL
260
+ for (int i = 0; i < kPredicateWordCount; ++i) {
261
+ predicates_[i] = enable ? 0u : predicates_[i];
262
+ }
263
+
264
+ }
265
+
266
+ /// Clears the predicate set efficiently
267
+ CUTLASS_HOST_DEVICE
268
+ void enable_mask() {
269
+ CUTLASS_PRAGMA_UNROLL
270
+ for (int i = 0; i < kPredicateWordCount; ++i) {
271
+ predicates_[i] = 0xffffffff;
272
+ }
273
+ }
274
+
275
+ /// Sets the predicate mask, overriding value stored in predicate iterator
276
+ CUTLASS_HOST_DEVICE
277
+ void set_mask(Mask const &mask) {
278
+ CUTLASS_PRAGMA_UNROLL
279
+ for (int i = 0; i < kPredicateWordCount; ++i) {
280
+ predicates_[i] = mask[i];
281
+ }
282
+
283
+ }
284
+
285
+ /// Gets the mask
286
+ CUTLASS_HOST_DEVICE
287
+ void get_mask(Mask &mask) {
288
+ CUTLASS_PRAGMA_UNROLL
289
+ for (int i = 0; i < kPredicateWordCount; ++i) {
290
+ mask[i] = predicates_[i];
291
+ }
292
+ }
293
+
294
+ /// Returns whether access is valid or not
295
+ CUTLASS_HOST_DEVICE
296
+ bool valid() const {
297
+
298
+
299
+ int pred_idx =
300
+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
301
+
302
+ int word_idx = pred_idx / kPredicatesPerWord;
303
+ int residual = pred_idx % kPredicatesPerWord;
304
+ int byte_idx = residual / kPredicatesPerByte;
305
+ int bit_idx = residual % kPredicatesPerByte;
306
+
307
+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
308
+ return pred;
309
+
310
+ }
311
+ };
312
+
313
+ ////////////////////////////////////////////////////////////////////////////////
314
+
315
+ /// PredicatedTileAccessIterator
316
+ ///
317
+ template <typename Shape, typename Element, typename Layout, int AdvanceRank,
318
+ typename ThreadMap, typename AccessType, bool Gather = false,
319
+ typename PermuteLayout = layout::NoPermute>
320
+ class PredicatedTileAccessIterator;
321
+
322
+ ////////////////////////////////////////////////////////////////////////////////
323
+
324
+ /// Specialization of PredicatedTileAccessIterator for pitch-linear data.
325
+ ///
326
+ template <typename Shape_, typename Element_, int AdvanceRank,
327
+ typename ThreadMap_, typename AccessType_, bool Gather,
328
+ typename PermuteLayout>
329
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
330
+ AdvanceRank, ThreadMap_, AccessType_, Gather,
331
+ PermuteLayout> {
332
+ public:
333
+ static_assert(
334
+ AdvanceRank == 0 || AdvanceRank == 1,
335
+ "Specialization for pitch-linear iterator may along advance along the "
336
+ "contiguous(rank=0) or strided(rank=1) dimension.");
337
+
338
+ using Shape = Shape_;
339
+ using Element = Element_;
340
+ using Layout = layout::PitchLinear;
341
+ static int const kAdvanceRank = AdvanceRank;
342
+ using ThreadMap = ThreadMap_;
343
+ using AccessType = AccessType_;
344
+
345
+ using Index = typename Layout::Index;
346
+ using LongIndex = typename Layout::LongIndex;
347
+
348
+ using TensorRef = TensorRef<Element, Layout>;
349
+ using TensorView = TensorView<Element, Layout>;
350
+ using TensorCoord = typename Layout::TensorCoord;
351
+
352
+ using Pointer = Element *;
353
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
354
+
355
+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
356
+ Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>;
357
+
358
+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
359
+
360
+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
361
+ "Vectors implied by the thread map must be divisible by the access type.");
362
+
363
+ static bool constexpr Permute = !platform::is_same<PermuteLayout, layout::NoPermute>::value
364
+ && !platform::is_same<PermuteLayout, layout::InversePermute<layout::NoPermute>>::value;
365
+
366
+ using Mask = typename UnderlyingPredicates::Mask;
367
+
368
+ /// Uses a non-template class
369
+ struct Params : PredicatedTileAccessIteratorParams {
370
+
371
+ using Base = PredicatedTileAccessIteratorParams;
372
+
373
+ /// Default constructor
374
+ Params() = default;
375
+
376
+ /// Construct the Params object given a pitch-linear tensor's layout
377
+ CUTLASS_HOST_DEVICE
378
+ Params(Layout const &layout) :
379
+ Base(layout.stride(0),
380
+ MakePredicatedTileAccessIteratorDesc<Shape, Element, Layout, kAdvanceRank, ThreadMap>()()
381
+ ) { }
382
+
383
+ CUTLASS_HOST_DEVICE
384
+ Params(Base const &base) :
385
+ Base(base) { }
386
+ };
387
+
388
+ private:
389
+ /// Internal pointer type permits fast address arithmetic
390
+ using BytePointer = char *;
391
+
392
+ private:
393
+ //
394
+ // Data members
395
+ //
396
+
397
+ UnderlyingPredicates the_predicates;
398
+
399
+ /// Parameters object with precomputed internal state
400
+ Params params_;
401
+
402
+ /// Internal pointer to first access of tile
403
+ BytePointer pointer_;
404
+
405
+ /// Used for out-of-order visitation
406
+ bool is_residue_tile_;
407
+
408
+ /// Below is used when Gather is turned on. We need to record strided_offset
409
+ /// and contiguous_offset separated to compute the offset by using
410
+ ///
411
+ /// offset = contiguous_offset + indices[strided_offset]
412
+
413
+ /// Gather indices
414
+ int const *indices_;
415
+
416
+ /// Function to perform layout permutation and offset computation
417
+ PermuteLayout permute_layout_;
418
+
419
+ /// Tracks thread's coordinate offset in the matrix for current tile.
420
+ /// This is only used in the following cases:
421
+ /// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_)
422
+ /// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed)
423
+ TensorCoord coord_offset_;
424
+
425
+ private:
426
+ /// Computes predicates based on internally tracked per-thread offset.
427
+ CUTLASS_DEVICE
428
+ void compute_predicates_(
429
+ /// Extent of the matrix window
430
+ TensorCoord extent,
431
+ /// optionally, simplify predicate calculation during 'steady state' phase
432
+ bool is_steady_state = false) {
433
+ the_predicates.compute_predicates_(extent, is_steady_state);
434
+ }
435
+
436
+ public:
437
+
438
+ /// Default constructor
439
+ PredicatedTileAccessIterator() = default;
440
+
441
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
442
+ /// and thread ID
443
+ CUTLASS_HOST_DEVICE
444
+ PredicatedTileAccessIterator(
445
+ /// Precomputed parameters object
446
+ Params const &params,
447
+ /// Pointer to start of tensor
448
+ Pointer pointer,
449
+ /// Extent of tensor
450
+ TensorCoord extent,
451
+ /// ID of each participating thread
452
+ int thread_id,
453
+ /// Initial offset of threadblock
454
+ TensorCoord const &threadblock_offset,
455
+ /// Gather indices
456
+ int const *indices = nullptr)
457
+ : params_(params),
458
+ pointer_(reinterpret_cast<BytePointer>(
459
+ const_cast<NonConstPointer>(pointer))),
460
+ the_predicates(extent),
461
+ is_residue_tile_(true),
462
+ indices_(indices),
463
+ permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) {
464
+
465
+ the_predicates.set_predicates(thread_id, threadblock_offset);
466
+
467
+ if (Gather) {
468
+ assert(indices_);
469
+ }
470
+
471
+ // update internal pointers
472
+ Layout layout(params_.stride_);
473
+
474
+ if (!Gather && !Permute) {
475
+ add_pointer_offset(layout(the_predicates.thread_offset_));
476
+ } else {
477
+ coord_offset_ = the_predicates.thread_offset_;
478
+ if (!Permute) {
479
+ add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0)));
480
+ }
481
+ }
482
+ }
483
+
484
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
485
+ CUTLASS_HOST_DEVICE
486
+ PredicatedTileAccessIterator(
487
+ /// Precomputed parameters object
488
+ Params const &params,
489
+ /// Pointer to start of tensor
490
+ Pointer pointer,
491
+ /// Extent of tensor
492
+ TensorCoord extent,
493
+ ///< ID of each participating thread
494
+ int thread_id)
495
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
496
+ make_Coord(0, 0)) {}
497
+
498
+ /// Overrides the internal iteration index
499
+ CUTLASS_HOST_DEVICE
500
+ void set_iteration_index(int index) {
501
+ the_predicates.set_iteration_index(index);
502
+ }
503
+
504
+ /// Adds a pointer offset in units of Element
505
+ CUTLASS_HOST_DEVICE
506
+ void add_pointer_offset(LongIndex pointer_offset) {
507
+ pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
508
+ }
509
+
510
+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles
511
+ CUTLASS_DEVICE
512
+ void add_tile_offset(
513
+ TensorCoord const &tile_offset) {
514
+ if (is_residue_tile_) {
515
+
516
+ the_predicates.thread_offset_ += the_predicates.residue_offset_;
517
+
518
+ the_predicates.compute_predicates_(the_predicates.extent_, true);
519
+
520
+ Layout layout(params_.stride_);
521
+
522
+ if (!Gather && !Permute) {
523
+ add_pointer_offset(layout(the_predicates.residue_offset_));
524
+
525
+ if (kAdvanceRank) {
526
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
527
+ pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits<Element>::value / 8;
528
+ } else {
529
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
530
+ pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits<Element>::value / 8;
531
+ }
532
+ } else {
533
+ coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank);
534
+ if (!Permute) {
535
+ add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0)));
536
+ add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank)));
537
+ } else {
538
+ coord_offset_.contiguous() = the_predicates.thread_offset_.contiguous() + Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank));
539
+ }
540
+ }
541
+ } else {
542
+ if (!Gather && !Permute) {
543
+ if (kAdvanceRank) {
544
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
545
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
546
+ } else {
547
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
548
+ pointer_ += Shape::kStrided * tile_offset.strided();
549
+ }
550
+ } else {
551
+ coord_offset_.strided() += Shape::kStrided * tile_offset.strided();
552
+ if (!Permute) {
553
+ add_pointer_offset(Shape::kContiguous * tile_offset.contiguous());
554
+ } else {
555
+ coord_offset_.contiguous() += Shape::kContiguous * tile_offset.contiguous();
556
+ }
557
+ }
558
+ }
559
+
560
+ is_residue_tile_ = false;
561
+ }
562
+
563
+ /// Returns a pointer
564
+ CUTLASS_HOST_DEVICE
565
+ AccessType *get() const {
566
+
567
+ if (Gather || Permute)
568
+ {
569
+ if (!valid()) {
570
+ return nullptr;
571
+ }
572
+
573
+ Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements;
574
+ Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided;
575
+ if (Gather) {
576
+ coord_strided = indices_[coord_strided];
577
+ }
578
+
579
+ LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig);
580
+ return reinterpret_cast<AccessType *>(pointer_ + OffsetBytes<Element>(offset));
581
+ }
582
+
583
+ return reinterpret_cast<AccessType *>(
584
+ pointer_ +
585
+ the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + the_predicates.iteration_vector_;
586
+ }
587
+
588
+ /// Increment and return an instance to self.
589
+ CUTLASS_HOST_DEVICE
590
+ PredicatedTileAccessIterator &operator++() {
591
+
592
+ the_predicates.operator++();
593
+
594
+ ++the_predicates.iteration_vector_;
595
+ if (the_predicates.iteration_vector_ < kAccessesPerVector) {
596
+ return *this;
597
+ }
598
+
599
+ the_predicates.iteration_vector_ = 0;
600
+ ++the_predicates.iteration_contiguous_;
601
+
602
+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
603
+ return *this;
604
+ }
605
+
606
+ // Enter here only if (iteration_contiguous_ == ThreadMap::Iteration::kContiguous)
607
+ the_predicates.iteration_contiguous_ = 0;
608
+ ++the_predicates.iteration_strided_;
609
+
610
+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
611
+ if (!Gather && !Permute) {
612
+ pointer_ += params_.inc_strided_;
613
+ }
614
+
615
+ return *this;
616
+ }
617
+
618
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
619
+ // which means we enter the next tile.
620
+ the_predicates.iteration_strided_ = 0;
621
+
622
+ if (!Gather && !Permute) {
623
+ // advance to next tile
624
+ pointer_ += params_.inc_next_;
625
+
626
+ // now return to start tile - if the iterator is subsequently advanced, this
627
+ // subtraction as well as the subsequent integer addition are both elided by
628
+ // the compiler.
629
+ pointer_ -= params_.inc_advance_;
630
+ }
631
+
632
+ return *this;
633
+ }
634
+
635
+ /// Increment and return an instance to self.
636
+ CUTLASS_HOST_DEVICE
637
+ PredicatedTileAccessIterator operator++(int) {
638
+ PredicatedTileAccessIterator self(*this);
639
+ operator++();
640
+ return self;
641
+ }
642
+
643
+ /// Clears the predicate set efficiently
644
+ CUTLASS_HOST_DEVICE
645
+ void clear_mask(bool enable = true) {
646
+ the_predicates.clear_mask(enable);
647
+ }
648
+
649
+ /// Clears the predicate set efficiently
650
+ CUTLASS_HOST_DEVICE
651
+ void enable_mask() {
652
+ the_predicates.enable_mask();
653
+ }
654
+
655
+ /// Sets the predicate mask, overriding value stored in predicate iterator
656
+ CUTLASS_HOST_DEVICE
657
+ void set_mask(Mask const &mask) {
658
+ the_predicates.set_mask(mask);
659
+ }
660
+
661
+ /// Gets the mask
662
+ CUTLASS_HOST_DEVICE
663
+ void get_mask(Mask &mask) {
664
+ the_predicates.get_mask(mask);
665
+ }
666
+
667
+ /// Returns whether access is valid or not
668
+ CUTLASS_HOST_DEVICE
669
+ bool valid() const {
670
+ return the_predicates.valid();
671
+ }
672
+ };
673
+
674
+ ////////////////////////////////////////////////////////////////////////////////
675
+
676
+ /// Specialization of PredicatedTileAccessIterator for column-major data.
677
+ ///
678
+ /// Satisfies: ForwardTileIteratorConcept |
679
+ /// ReadableContiguousTileIteratorConcept |
680
+ /// WriteableContiguousTileIteratorConcept |
681
+ /// MaskedTileIteratorConcept
682
+ ///
683
+ template <typename Shape_, typename Element_, int AdvanceRank,
684
+ typename ThreadMap_, typename AccessType_, bool Gather,
685
+ typename PermuteLayout>
686
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
687
+ AdvanceRank, ThreadMap_, AccessType_, Gather,
688
+ PermuteLayout> {
689
+ public:
690
+ static_assert(
691
+ AdvanceRank == 0 || AdvanceRank == 1,
692
+ "Specialization for pitch-linear iterator may along advance along the "
693
+ "contiguous(rank=0) or strided(rank=1) dimension.");
694
+
695
+ using Shape = Shape_;
696
+ using Element = Element_;
697
+ using Layout = layout::ColumnMajor;
698
+ static int const kAdvanceRank = AdvanceRank;
699
+ using ThreadMap = ThreadMap_;
700
+ using AccessType = AccessType_;
701
+
702
+ using Index = typename Layout::Index;
703
+ using LongIndex = typename Layout::LongIndex;
704
+
705
+ using TensorRef = TensorRef<Element, Layout>;
706
+ using TensorView = TensorView<Element, Layout>;
707
+ using TensorCoord = typename Layout::TensorCoord;
708
+
709
+ using Pointer = Element *;
710
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
711
+
712
+ using UnderlyingIterator = PredicatedTileAccessIterator<
713
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
714
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType,
715
+ Gather, PermuteLayout>;
716
+
717
+ /// Predicate vector stores mask to guard accesses
718
+ using Mask = typename UnderlyingIterator::Mask;
719
+
720
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
721
+
722
+ /// Parameters object is precomputed state and is host-constructible
723
+ class Params {
724
+ private:
725
+ friend PredicatedTileAccessIterator;
726
+
727
+ /// Parameters object
728
+ typename UnderlyingIterator::Params params_;
729
+
730
+ public:
731
+
732
+ /// Default constructor
733
+ Params() = default;
734
+
735
+ /// Construct the Params object given a pitch-linear tensor's layout
736
+ CUTLASS_HOST_DEVICE
737
+ Params(Layout const &layout)
738
+ : params_(layout::PitchLinear(layout.stride(0))){};
739
+
740
+ /// Construct the Params object given a pitch-linear tensor's layout
741
+ CUTLASS_HOST_DEVICE
742
+ Params(typename UnderlyingIterator::Params::Base const &base)
743
+ : params_(base) {}
744
+ };
745
+
746
+ private:
747
+ //
748
+ // Data members
749
+ //
750
+
751
+ /// Underlying pitch-linear tile iterator
752
+ UnderlyingIterator iterator_;
753
+
754
+ public:
755
+
756
+ /// Default constructor
757
+ PredicatedTileAccessIterator() = default;
758
+
759
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
760
+ /// and thread ID
761
+ CUTLASS_HOST_DEVICE
762
+ PredicatedTileAccessIterator(
763
+ ///< Precomputed parameters object
764
+ Params const &params,
765
+ ///< Pointer to start of tensor
766
+ Pointer pointer,
767
+ ///< Extent of tensor
768
+ TensorCoord extent,
769
+ ///< ID of each participating thread
770
+ int thread_id,
771
+ ///< Initial offset of threadblock
772
+ TensorCoord const &threadblock_offset,
773
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
774
+ )
775
+ : iterator_(params.params_, pointer,
776
+ layout::PitchLinearCoord(extent.row(), extent.column()),
777
+ thread_id,
778
+ layout::PitchLinearCoord(threadblock_offset.row(),
779
+ threadblock_offset.column()),
780
+ indices) {}
781
+
782
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
783
+ CUTLASS_HOST_DEVICE
784
+ PredicatedTileAccessIterator(
785
+ Params const &params, ///< Precomputed parameters object
786
+ Pointer pointer, ///< Pointer to start of tensor
787
+ TensorCoord extent, ///< Extent of tensor
788
+ int thread_id ///< ID of each participating thread
789
+ )
790
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
791
+ make_Coord(0, 0)) {}
792
+
793
+ /// Overrides the internal iteration index
794
+ CUTLASS_HOST_DEVICE
795
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
796
+
797
+ /// Adds a pointer offset in units of Element
798
+ CUTLASS_HOST_DEVICE
799
+ void add_pointer_offset(LongIndex pointer_offset) {
800
+ iterator_.add_pointer_offset(pointer_offset);
801
+ }
802
+
803
+ /// Advances an iterator along logical dimensions of matrix in units of whole
804
+ /// tiles
805
+ CUTLASS_HOST_DEVICE
806
+ void add_tile_offset(TensorCoord const &tile_offset) {
807
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
808
+ }
809
+
810
+ /// Returns a pointer
811
+ CUTLASS_HOST_DEVICE
812
+ AccessType *get() const {
813
+ return reinterpret_cast<AccessType *>(iterator_.get());
814
+ }
815
+
816
+ /// Advances to the next tile in memory.
817
+ ///
818
+ /// The first time this method is called, predicates are updated, and the
819
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
820
+ /// Subsequent calls are lightweight and must only update the internal
821
+ /// pointer.
822
+ CUTLASS_HOST_DEVICE
823
+ PredicatedTileAccessIterator &operator++() {
824
+ ++iterator_;
825
+ return *this;
826
+ }
827
+
828
+ /// Advances to the next tile in memory.
829
+ ///
830
+ /// The first time this method is called, predicates are updated, and the
831
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
832
+ /// Subsequent calls are lightweight and must only update the internal
833
+ /// pointer.
834
+ CUTLASS_HOST_DEVICE
835
+ PredicatedTileAccessIterator operator++(int) {
836
+ PredicatedTileAccessIterator self(*this);
837
+ operator++();
838
+ return self;
839
+ }
840
+
841
+ /// Clears the predicate set efficiently
842
+ CUTLASS_HOST_DEVICE
843
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
844
+
845
+ /// Clears the predicate set efficiently
846
+ CUTLASS_HOST_DEVICE
847
+ void enable_mask() { iterator_.enable_mask(); }
848
+
849
+ /// Sets the predicate mask, overriding value stored in predicate iterator
850
+ CUTLASS_HOST_DEVICE
851
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
852
+
853
+ /// Gets the mask
854
+ CUTLASS_HOST_DEVICE
855
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
856
+
857
+ /// Returns whether access is valid or not
858
+ CUTLASS_HOST_DEVICE
859
+ bool valid() {
860
+ return iterator_.valid();
861
+ }
862
+ };
863
+
864
+ ////////////////////////////////////////////////////////////////////////////////
865
+
866
+ /// Specialization of PredicatedTileAccessIterator for row-major data.
867
+ ///
868
+ /// Satisfies: ForwardTileIteratorConcept |
869
+ /// ReadableContiguousTileIteratorConcept |
870
+ /// WriteableContiguousTileIteratorConcept |
871
+ /// MaskedTileIteratorConcept
872
+ ///
873
+ template <typename Shape_, typename Element_, int AdvanceRank,
874
+ typename ThreadMap_, typename AccessType_, bool Gather,
875
+ typename PermuteLayout>
876
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
877
+ AdvanceRank, ThreadMap_, AccessType_, Gather,
878
+ PermuteLayout> {
879
+ public:
880
+ static_assert(
881
+ AdvanceRank == 0 || AdvanceRank == 1,
882
+ "Specialization for pitch-linear iterator may along advance along the "
883
+ "contiguous(rank=0) or strided(rank=1) dimension.");
884
+
885
+ using Shape = Shape_;
886
+ using Element = Element_;
887
+ using Layout = layout::RowMajor;
888
+ static int const kAdvanceRank = AdvanceRank;
889
+ using ThreadMap = ThreadMap_;
890
+ using AccessType = AccessType_;
891
+
892
+ using Index = typename Layout::Index;
893
+ using LongIndex = typename Layout::LongIndex;
894
+
895
+ using TensorRef = TensorRef<Element, Layout>;
896
+ using TensorView = TensorView<Element, Layout>;
897
+ using TensorCoord = typename Layout::TensorCoord;
898
+
899
+ using Pointer = Element *;
900
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
901
+
902
+ using UnderlyingIterator = PredicatedTileAccessIterator<
903
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
904
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType,
905
+ Gather, PermuteLayout>;
906
+
907
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
908
+
909
+ /// Predicate vector stores mask to guard accesses
910
+ using Mask = typename UnderlyingIterator::Mask;
911
+
912
+ /// Parameters object is precomputed state and is host-constructible
913
+ class Params {
914
+ private:
915
+ friend PredicatedTileAccessIterator;
916
+
917
+ /// Parameters object
918
+ typename UnderlyingIterator::Params params_;
919
+
920
+ public:
921
+
922
+ /// Default constructor
923
+ Params() = default;
924
+
925
+ /// Construct the Params object given a pitch-linear tensor's layout
926
+ CUTLASS_HOST_DEVICE
927
+ Params(Layout const &layout)
928
+ : params_(layout::PitchLinear(layout.stride(0))){};
929
+
930
+ /// Construct the Params object given a pitch-linear tensor's layout
931
+ CUTLASS_HOST_DEVICE
932
+ Params(typename UnderlyingIterator::Params::Base const &base)
933
+ : params_(base) {}
934
+ };
935
+
936
+ private:
937
+ //
938
+ // Data members
939
+ //
940
+
941
+ /// Underlying pitch-linear tile iterator
942
+ UnderlyingIterator iterator_;
943
+
944
+ public:
945
+
946
+ /// Default constructor
947
+ PredicatedTileAccessIterator() = default;
948
+
949
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
950
+ /// and thread ID
951
+ CUTLASS_HOST_DEVICE
952
+ PredicatedTileAccessIterator(
953
+ ///< Precomputed parameters object
954
+ Params const &params,
955
+ ///< Pointer to start of tensor
956
+ Pointer pointer,
957
+ ///< Extent of tensor
958
+ TensorCoord extent,
959
+ ///< ID of each participating thread
960
+ int thread_id,
961
+ ///< Initial offset of threadblock
962
+ TensorCoord const &threadblock_offset,
963
+ /// Gather indices
964
+ int const *indices = nullptr)
965
+ : iterator_(params.params_, pointer,
966
+ layout::PitchLinearCoord(extent.column(), extent.row()),
967
+ thread_id,
968
+ layout::PitchLinearCoord(threadblock_offset.column(),
969
+ threadblock_offset.row()),
970
+ indices) {}
971
+
972
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
973
+ CUTLASS_HOST_DEVICE
974
+ PredicatedTileAccessIterator(
975
+ Params const &params, ///< Precomputed parameters object
976
+ Pointer pointer, ///< Pointer to start of tensor
977
+ TensorCoord extent, ///< Extent of tensor
978
+ int thread_id ///< ID of each participating thread
979
+ )
980
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
981
+ make_Coord(0, 0)) {}
982
+
983
+ /// Overrides the internal iteration index
984
+ CUTLASS_HOST_DEVICE
985
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
986
+
987
+ /// Adds a pointer offset in units of Element
988
+ CUTLASS_HOST_DEVICE
989
+ void add_pointer_offset(LongIndex pointer_offset) {
990
+ iterator_.add_pointer_offset(pointer_offset);
991
+ }
992
+
993
+ /// Advances an iterator along logical dimensions of matrix in units of whole
994
+ /// tiles
995
+ CUTLASS_HOST_DEVICE
996
+ void add_tile_offset(TensorCoord const &tile_offset) {
997
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
998
+ }
999
+
1000
+ /// Returns a pointer
1001
+ CUTLASS_HOST_DEVICE
1002
+ AccessType *get() const {
1003
+ return reinterpret_cast<AccessType *>(iterator_.get());
1004
+ }
1005
+
1006
+ /// Advances to the next tile in memory.
1007
+ ///
1008
+ /// The first time this method is called, predicates are updated, and the
1009
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1010
+ /// Subsequent calls are lightweight and must only update the internal
1011
+ /// pointer.
1012
+ CUTLASS_HOST_DEVICE
1013
+ PredicatedTileAccessIterator &operator++() {
1014
+ ++iterator_;
1015
+ return *this;
1016
+ }
1017
+
1018
+ /// Advances to the next tile in memory.
1019
+ ///
1020
+ /// The first time this method is called, predicates are updated, and the
1021
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1022
+ /// Subsequent calls are lightweight and must only update the internal
1023
+ /// pointer.
1024
+ CUTLASS_HOST_DEVICE
1025
+ PredicatedTileAccessIterator operator++(int) {
1026
+ PredicatedTileAccessIterator self(*this);
1027
+ operator++();
1028
+ return self;
1029
+ }
1030
+
1031
+ /// Clears the predicate set efficiently
1032
+ CUTLASS_HOST_DEVICE
1033
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1034
+
1035
+ /// Clears the predicate set efficiently
1036
+ CUTLASS_HOST_DEVICE
1037
+ void enable_mask() { iterator_.enable_mask(); }
1038
+
1039
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1040
+ CUTLASS_HOST_DEVICE
1041
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1042
+
1043
+ /// Gets the mask
1044
+ CUTLASS_HOST_DEVICE
1045
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1046
+
1047
+ /// Returns whether access is valid or not
1048
+ CUTLASS_HOST_DEVICE
1049
+ bool valid() {
1050
+ return iterator_.valid();
1051
+ }
1052
+ };
1053
+
1054
+ ////////////////////////////////////////////////////////////////////////////////
1055
+
1056
+ /// Specialization of PredicatedTileAccessIterator for affine rank 2 data.
1057
+ ///
1058
+ /// Satisfies: ForwardTileIteratorConcept |
1059
+ /// ReadableContiguousTileIteratorConcept |
1060
+ /// WriteableContiguousTileIteratorConcept |
1061
+ /// MaskedTileIteratorConcept
1062
+ ///
1063
+ template <typename Shape_, typename Element_, int AdvanceRank,
1064
+ typename ThreadMap_, typename AccessType_>
1065
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRankN<2>,
1066
+ AdvanceRank, ThreadMap_, AccessType_, false,
1067
+ layout::NoPermute> {
1068
+ public:
1069
+ static_assert(
1070
+ AdvanceRank == 0 || AdvanceRank == 1,
1071
+ "Specialization for pitch-linear iterator may along advance along the "
1072
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1073
+
1074
+ using Shape = Shape_;
1075
+ using Element = Element_;
1076
+ using Layout = layout::AffineRankN<2>;
1077
+ static int const kAdvanceRank = AdvanceRank;
1078
+ using ThreadMap = ThreadMap_;
1079
+ using AccessType = AccessType_;
1080
+
1081
+ using Index = typename Layout::Index;
1082
+ using LongIndex = typename Layout::LongIndex;
1083
+
1084
+ using TensorRef = TensorRef<Element, Layout>;
1085
+ using TensorView = TensorView<Element, Layout>;
1086
+ using TensorCoord = typename Layout::TensorCoord;
1087
+
1088
+ using Pointer = Element *;
1089
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1090
+
1091
+ using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
1092
+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>;
1093
+
1094
+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
1095
+
1096
+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
1097
+ "Vectors implied by the thread map must be divisible by the access type.");
1098
+
1099
+ /// Predicate vector stores mask to guard accesses
1100
+ using Mask = typename UnderlyingPredicates::Mask;
1101
+
1102
+ /// Parameters object is precomputed state and is host-constructible
1103
+ class Params {
1104
+ public:
1105
+ friend PredicatedTileAccessIterator;
1106
+
1107
+ private:
1108
+ /// stride of pitch-linear layout (units of Element)
1109
+ Coord<Layout::kStrideRank, Layout::LongIndex> stride_;
1110
+ /// amount (in byte) to increment pointer to move to next access along
1111
+ /// contiguous dimension
1112
+ LongIndex inc_contiguous_;
1113
+ /// amount (in byte) to increment pointer from first access of current
1114
+ /// contiguous dimension to first access of next one.
1115
+ LongIndex inc_strided_;
1116
+ /// amount (in byte) to increment pointer from last access of current
1117
+ /// contiguous dimension to first access of next one.
1118
+ LongIndex inc_next_strided_;
1119
+ /// amount (in byte) to increment pointer from last access to first access
1120
+ /// of next tile
1121
+ LongIndex inc_next_;
1122
+ /// amount (in byte) to increment pointer from first access of current tile
1123
+ /// to first access of next tile
1124
+ LongIndex inc_advance_;
1125
+
1126
+ public:
1127
+
1128
+ // Default ctor
1129
+ CUTLASS_HOST_DEVICE
1130
+ Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
1131
+
1132
+ /// Construct the Params object given a pitch-linear tensor's layout
1133
+ CUTLASS_HOST_DEVICE
1134
+ Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) {
1135
+ inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) *
1136
+ sizeof_bits<Element>::value / 8;
1137
+
1138
+ inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) *
1139
+ sizeof_bits<Element>::value / 8;
1140
+
1141
+ inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_;
1142
+
1143
+ if (kAdvanceRank) {
1144
+ // advance along strided dimension
1145
+ inc_advance_ =
1146
+ Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits<Element>::value / 8;
1147
+ } else {
1148
+ // advance along contiguous dimension
1149
+ inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits<Element>::value / 8;
1150
+ }
1151
+
1152
+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_;
1153
+ };
1154
+ };
1155
+
1156
+ private:
1157
+ /// Internal pointer type permits fast address arithmetic
1158
+ using BytePointer = char *;
1159
+
1160
+ //
1161
+ // Data members
1162
+ //
1163
+
1164
+ /// Parameters object with precomputed internal state
1165
+ Params params_;
1166
+
1167
+ /// Internal pointer to first access of tile
1168
+ BytePointer pointer_;
1169
+
1170
+ UnderlyingPredicates the_predicates;
1171
+
1172
+ /// Used for out-of-order visitation
1173
+ bool is_residue_tile_;
1174
+
1175
+ private:
1176
+ /// Computes predicates based on internally tracked per-thread offset.
1177
+ CUTLASS_DEVICE
1178
+ void compute_predicates_(
1179
+ /// Extent of the matrix window
1180
+ TensorCoord extent,
1181
+ /// optionally, simplify predicate calculation during 'steady state' phase
1182
+ bool is_steady_state = false) {
1183
+ the_predicates.compute_predicates_(extent, is_steady_state);
1184
+ }
1185
+
1186
+ public:
1187
+
1188
+ /// Default constructor
1189
+ PredicatedTileAccessIterator() = default;
1190
+
1191
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1192
+ /// and thread ID
1193
+ CUTLASS_HOST_DEVICE
1194
+ PredicatedTileAccessIterator(
1195
+ ///< Precomputed parameters object
1196
+ Params const &params,
1197
+ ///< Pointer to start of tensor
1198
+ Pointer pointer,
1199
+ ///< Extent of tensor
1200
+ TensorCoord extent,
1201
+ ///< ID of each participating thread
1202
+ int thread_id,
1203
+ ///< Initial offset of threadblock
1204
+ TensorCoord const &threadblock_offset,
1205
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1206
+ )
1207
+ : params_(params),
1208
+ pointer_(reinterpret_cast<BytePointer>(
1209
+ const_cast<NonConstPointer>(pointer))),
1210
+ the_predicates(extent),
1211
+ is_residue_tile_(true) {
1212
+
1213
+ the_predicates.set_predicates(thread_id, threadblock_offset);
1214
+
1215
+ // update internal pointers
1216
+ Layout layout(params_.stride_);
1217
+ add_pointer_offset(layout(the_predicates.thread_offset_));
1218
+ }
1219
+
1220
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
1221
+ CUTLASS_HOST_DEVICE
1222
+ PredicatedTileAccessIterator(
1223
+ Params const &params, ///< Precomputed parameters object
1224
+ Pointer pointer, ///< Pointer to start of tensor
1225
+ TensorCoord extent, ///< Extent of tensor
1226
+ int thread_id ///< ID of each participating thread
1227
+ )
1228
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1229
+ make_Coord(0, 0)) {}
1230
+
1231
+ /// Overrides the internal iteration index
1232
+ CUTLASS_HOST_DEVICE
1233
+ void set_iteration_index(int index) { the_predicates.set_iteration_index(index); }
1234
+
1235
+ /// Adds a pointer offset in units of Element
1236
+ CUTLASS_HOST_DEVICE
1237
+ void add_pointer_offset(LongIndex pointer_offset) {
1238
+ pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
1239
+ }
1240
+
1241
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1242
+ /// tiles
1243
+ CUTLASS_HOST_DEVICE
1244
+ void add_tile_offset(TensorCoord const &tile_offset) {
1245
+ if (is_residue_tile_) {
1246
+
1247
+ the_predicates.thread_offset_ += the_predicates.residue_offset_;
1248
+
1249
+ Layout layout(params_.stride_);
1250
+ add_pointer_offset(layout(the_predicates.residue_offset_));
1251
+
1252
+ the_predicates.compute_predicates_(the_predicates.extent_, true);
1253
+
1254
+ if (kAdvanceRank) {
1255
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1);
1256
+ pointer_ += Shape::kContiguous * tile_offset[0];
1257
+ } else {
1258
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1);
1259
+ pointer_ += Shape::kStrided * tile_offset[1];
1260
+ }
1261
+ } else {
1262
+ if (kAdvanceRank) {
1263
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]);
1264
+ pointer_ += Shape::kContiguous * tile_offset[0];
1265
+ } else {
1266
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]);
1267
+ pointer_ += Shape::kStrided * tile_offset[1];
1268
+ }
1269
+ }
1270
+ is_residue_tile_ = false;
1271
+ }
1272
+
1273
+ /// Returns a pointer
1274
+ CUTLASS_HOST_DEVICE
1275
+ AccessType *get() const {
1276
+ return reinterpret_cast<AccessType *>(pointer_) + the_predicates.iteration_vector_;
1277
+ }
1278
+
1279
+ /// Advances to the next tile in memory.
1280
+ ///
1281
+ /// The first time this method is called, predicates are updated, and the
1282
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1283
+ /// Subsequent calls are lightweight and must only update the internal
1284
+ /// pointer.
1285
+ CUTLASS_HOST_DEVICE
1286
+ PredicatedTileAccessIterator &operator++() {
1287
+ the_predicates.operator++();
1288
+ ++the_predicates.iteration_vector_;
1289
+ if (the_predicates.iteration_vector_ < kAccessesPerVector) {
1290
+ return *this;
1291
+ }
1292
+
1293
+ the_predicates.iteration_vector_ = 0;
1294
+ ++the_predicates.iteration_contiguous_;
1295
+
1296
+ if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
1297
+ pointer_ += params_.inc_contiguous_;
1298
+ return *this;
1299
+ }
1300
+
1301
+ // Enter here only if (iteration_contiguous_ ==
1302
+ // ThreadMap::Iteration::kContiguous)
1303
+ the_predicates.iteration_contiguous_ = 0;
1304
+ ++the_predicates.iteration_strided_;
1305
+
1306
+ if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
1307
+ pointer_ += params_.inc_next_strided_;
1308
+ return *this;
1309
+ }
1310
+
1311
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
1312
+ // which means we enter the next tile.
1313
+ the_predicates.iteration_strided_ = 0;
1314
+
1315
+ // advance to next tile
1316
+ pointer_ += params_.inc_next_;
1317
+
1318
+ // now return to start tile - if the iterator is subsequently advanced, this
1319
+ // subtraction as well as the subsequent integer addition are both elided by
1320
+ // the compiler.
1321
+ pointer_ -= params_.inc_advance_;
1322
+
1323
+ return *this;
1324
+ }
1325
+
1326
+ /// Advances to the next tile in memory.
1327
+ ///
1328
+ /// The first time this method is called, predicates are updated, and the
1329
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1330
+ /// Subsequent calls are lightweight and must only update the internal
1331
+ /// pointer.
1332
+ CUTLASS_HOST_DEVICE
1333
+ PredicatedTileAccessIterator operator++(int) {
1334
+ PredicatedTileAccessIterator self(*this);
1335
+ operator++();
1336
+ return self;
1337
+ }
1338
+
1339
+ /// Clears the predicate set efficiently
1340
+ CUTLASS_HOST_DEVICE
1341
+ void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
1342
+
1343
+ /// Clears the predicate set efficiently
1344
+ CUTLASS_HOST_DEVICE
1345
+ void enable_mask() { the_predicates.enable_mask(); }
1346
+
1347
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1348
+ CUTLASS_HOST_DEVICE
1349
+ void set_mask(Mask const &mask) { the_predicates.set_mask(mask); }
1350
+
1351
+ /// Gets the mask
1352
+ CUTLASS_HOST_DEVICE
1353
+ void get_mask(Mask &mask) { the_predicates.get_mask(mask); }
1354
+
1355
+ /// Returns whether access is valid or not
1356
+ CUTLASS_HOST_DEVICE
1357
+ bool valid() {
1358
+ return the_predicates.valid();
1359
+ }
1360
+ };
1361
+
1362
+ ////////////////////////////////////////////////////////////////////////////////
1363
+
1364
+ /// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data.
1365
+ ///
1366
+ /// Satisfies: ForwardTileIteratorConcept |
1367
+ /// ReadableContiguousTileIteratorConcept |
1368
+ /// WriteableContiguousTileIteratorConcept |
1369
+ /// MaskedTileIteratorConcept
1370
+ ///
1371
+ template <typename Shape_, typename Element_, int AdvanceRank,
1372
+ typename ThreadMap_, typename AccessType_>
1373
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2ColumnMajor,
1374
+ AdvanceRank, ThreadMap_, AccessType_, false,
1375
+ layout::NoPermute> {
1376
+ public:
1377
+ static_assert(
1378
+ AdvanceRank == 0 || AdvanceRank == 1,
1379
+ "Specialization for pitch-linear iterator may along advance along the "
1380
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1381
+
1382
+ using Shape = Shape_;
1383
+ using Element = Element_;
1384
+ using Layout = layout::AffineRank2ColumnMajor;
1385
+ static int const kAdvanceRank = AdvanceRank;
1386
+ using ThreadMap = ThreadMap_;
1387
+ using AccessType = AccessType_;
1388
+
1389
+ using Index = typename Layout::Index;
1390
+ using LongIndex = typename Layout::LongIndex;
1391
+
1392
+ using TensorRef = TensorRef<Element, Layout>;
1393
+ using TensorView = TensorView<Element, Layout>;
1394
+ using TensorCoord = typename Layout::TensorCoord;
1395
+
1396
+ using Pointer = Element *;
1397
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1398
+
1399
+ // Map to the underlying AffineRankN<2> layout
1400
+ using UnderlyingIterator = PredicatedTileAccessIterator<
1401
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
1402
+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
1403
+
1404
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1405
+
1406
+ /// Predicate vector stores mask to guard accesses
1407
+ using Mask = typename UnderlyingIterator::Mask;
1408
+
1409
+ /// Parameters object is precomputed state and is host-constructible
1410
+ class Params {
1411
+ private:
1412
+ friend PredicatedTileAccessIterator;
1413
+
1414
+ /// Parameters object
1415
+ typename UnderlyingIterator::Params params_;
1416
+
1417
+ public:
1418
+
1419
+ /// Default constructor
1420
+ Params() = default;
1421
+
1422
+ /// Construct the Params object given an AffineRankN<2> tensor's layout
1423
+ CUTLASS_HOST_DEVICE
1424
+ Params(Layout const &layout)
1425
+ : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){};
1426
+ };
1427
+
1428
+ private:
1429
+ //
1430
+ // Data members
1431
+ //
1432
+
1433
+ /// Underlying AffineRankN<2> tile iterator
1434
+ UnderlyingIterator iterator_;
1435
+
1436
+ public:
1437
+
1438
+ /// Default constructor
1439
+ PredicatedTileAccessIterator() = default;
1440
+
1441
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1442
+ /// and thread ID
1443
+ CUTLASS_HOST_DEVICE
1444
+ PredicatedTileAccessIterator(
1445
+ ///< Precomputed parameters object
1446
+ Params const &params,
1447
+ ///< Pointer to start of tensor
1448
+ Pointer pointer,
1449
+ ///< Extent of tensor
1450
+ TensorCoord extent,
1451
+ ///< ID of each participating thread
1452
+ int thread_id,
1453
+ ///< Initial offset of threadblock
1454
+ TensorCoord const &threadblock_offset,
1455
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1456
+ )
1457
+ : iterator_(params.params_, pointer,
1458
+ layout::PitchLinearCoord(extent.row(), extent.column()),
1459
+ thread_id,
1460
+ layout::PitchLinearCoord(threadblock_offset.row(),
1461
+ threadblock_offset.column())) {}
1462
+
1463
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
1464
+ CUTLASS_HOST_DEVICE
1465
+ PredicatedTileAccessIterator(
1466
+ Params const &params, ///< Precomputed parameters object
1467
+ Pointer pointer, ///< Pointer to start of tensor
1468
+ TensorCoord extent, ///< Extent of tensor
1469
+ int thread_id ///< ID of each participating thread
1470
+ )
1471
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1472
+ make_Coord(0, 0)) {}
1473
+
1474
+ /// Overrides the internal iteration index
1475
+ CUTLASS_HOST_DEVICE
1476
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1477
+
1478
+ /// Adds a pointer offset in units of Element
1479
+ CUTLASS_HOST_DEVICE
1480
+ void add_pointer_offset(LongIndex pointer_offset) {
1481
+ iterator_.add_pointer_offset(pointer_offset);
1482
+ }
1483
+
1484
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1485
+ /// tiles
1486
+ CUTLASS_HOST_DEVICE
1487
+ void add_tile_offset(TensorCoord const &tile_offset) {
1488
+ iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column()));
1489
+ }
1490
+
1491
+ /// Returns a pointer
1492
+ CUTLASS_HOST_DEVICE
1493
+ AccessType *get() const {
1494
+ return reinterpret_cast<AccessType *>(iterator_.get());
1495
+ }
1496
+
1497
+ /// Advances to the next tile in memory.
1498
+ ///
1499
+ /// The first time this method is called, predicates are updated, and the
1500
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1501
+ /// Subsequent calls are lightweight and must only update the internal
1502
+ /// pointer.
1503
+ CUTLASS_HOST_DEVICE
1504
+ PredicatedTileAccessIterator &operator++() {
1505
+ ++iterator_;
1506
+ return *this;
1507
+ }
1508
+
1509
+ /// Advances to the next tile in memory.
1510
+ ///
1511
+ /// The first time this method is called, predicates are updated, and the
1512
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1513
+ /// Subsequent calls are lightweight and must only update the internal
1514
+ /// pointer.
1515
+ CUTLASS_HOST_DEVICE
1516
+ PredicatedTileAccessIterator operator++(int) {
1517
+ PredicatedTileAccessIterator self(*this);
1518
+ operator++();
1519
+ return self;
1520
+ }
1521
+
1522
+ /// Clears the predicate set efficiently
1523
+ CUTLASS_HOST_DEVICE
1524
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1525
+
1526
+ /// Clears the predicate set efficiently
1527
+ CUTLASS_HOST_DEVICE
1528
+ void enable_mask() { iterator_.enable_mask(); }
1529
+
1530
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1531
+ CUTLASS_HOST_DEVICE
1532
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1533
+
1534
+ /// Gets the mask
1535
+ CUTLASS_HOST_DEVICE
1536
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1537
+
1538
+ /// Returns whether access is valid or not
1539
+ CUTLASS_HOST_DEVICE
1540
+ bool valid() {
1541
+ return iterator_.valid();
1542
+ }
1543
+ };
1544
+
1545
+ ////////////////////////////////////////////////////////////////////////////////
1546
+
1547
+ /// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data.
1548
+ ///
1549
+ /// Satisfies: ForwardTileIteratorConcept |
1550
+ /// ReadableContiguousTileIteratorConcept |
1551
+ /// WriteableContiguousTileIteratorConcept |
1552
+ /// MaskedTileIteratorConcept
1553
+ ///
1554
+ template <typename Shape_, typename Element_, int AdvanceRank,
1555
+ typename ThreadMap_, typename AccessType_>
1556
+ class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2RowMajor,
1557
+ AdvanceRank, ThreadMap_, AccessType_, false,
1558
+ layout::NoPermute> {
1559
+ public:
1560
+ static_assert(
1561
+ AdvanceRank == 0 || AdvanceRank == 1,
1562
+ "Specialization for pitch-linear iterator may along advance along the "
1563
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1564
+
1565
+ using Shape = Shape_;
1566
+ using Element = Element_;
1567
+ using Layout = layout::AffineRank2RowMajor;
1568
+ static int const kAdvanceRank = AdvanceRank;
1569
+ using ThreadMap = ThreadMap_;
1570
+ using AccessType = AccessType_;
1571
+
1572
+ using Index = typename Layout::Index;
1573
+ using LongIndex = typename Layout::LongIndex;
1574
+
1575
+ using TensorRef = TensorRef<Element, Layout>;
1576
+ using TensorView = TensorView<Element, Layout>;
1577
+ using TensorCoord = typename Layout::TensorCoord;
1578
+
1579
+ using Pointer = Element *;
1580
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1581
+
1582
+ // Map to the underlying AffineRankN<2> layout
1583
+ using UnderlyingIterator = PredicatedTileAccessIterator<
1584
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
1585
+ layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
1586
+
1587
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1588
+
1589
+ /// Predicate vector stores mask to guard accesses
1590
+ using Mask = typename UnderlyingIterator::Mask;
1591
+
1592
+ /// Parameters object is precomputed state and is host-constructible
1593
+ class Params {
1594
+ private:
1595
+ friend PredicatedTileAccessIterator;
1596
+
1597
+ /// Parameters object
1598
+ typename UnderlyingIterator::Params params_;
1599
+
1600
+ public:
1601
+
1602
+ /// Default constructor
1603
+ Params() = default;
1604
+
1605
+ /// Construct the Params object given an AffineRankN<2> tensor's layout
1606
+ CUTLASS_HOST_DEVICE
1607
+ Params(Layout const &layout)
1608
+ : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){};
1609
+ };
1610
+
1611
+ private:
1612
+ //
1613
+ // Data members
1614
+ //
1615
+
1616
+ /// Underlying AffineRankN<2> tile iterator
1617
+ UnderlyingIterator iterator_;
1618
+
1619
+ public:
1620
+
1621
+ /// Default constructor
1622
+ PredicatedTileAccessIterator() = default;
1623
+
1624
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1625
+ /// and thread ID
1626
+ CUTLASS_HOST_DEVICE
1627
+ PredicatedTileAccessIterator(
1628
+ ///< Precomputed parameters object
1629
+ Params const &params,
1630
+ ///< Pointer to start of tensor
1631
+ Pointer pointer,
1632
+ ///< Extent of tensor
1633
+ TensorCoord extent,
1634
+ ///< ID of each participating thread
1635
+ int thread_id,
1636
+ ///< Initial offset of threadblock
1637
+ TensorCoord const &threadblock_offset,
1638
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1639
+ )
1640
+ : iterator_(params.params_, pointer,
1641
+ layout::PitchLinearCoord(extent.column(), extent.row()),
1642
+ thread_id,
1643
+ layout::PitchLinearCoord(threadblock_offset.column(),
1644
+ threadblock_offset.row())) {}
1645
+
1646
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
1647
+ CUTLASS_HOST_DEVICE
1648
+ PredicatedTileAccessIterator(
1649
+ Params const &params, ///< Precomputed parameters object
1650
+ Pointer pointer, ///< Pointer to start of tensor
1651
+ TensorCoord extent, ///< Extent of tensor
1652
+ int thread_id ///< ID of each participating thread
1653
+ )
1654
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1655
+ make_Coord(0, 0)) {}
1656
+
1657
+ /// Overrides the internal iteration index
1658
+ CUTLASS_HOST_DEVICE
1659
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1660
+
1661
+ /// Adds a pointer offset in units of Element
1662
+ CUTLASS_HOST_DEVICE
1663
+ void add_pointer_offset(LongIndex pointer_offset) {
1664
+ iterator_.add_pointer_offset(pointer_offset);
1665
+ }
1666
+
1667
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1668
+ /// tiles
1669
+ CUTLASS_HOST_DEVICE
1670
+ void add_tile_offset(TensorCoord const &tile_offset) {
1671
+ iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row()));
1672
+ }
1673
+
1674
+ /// Returns a pointer
1675
+ CUTLASS_HOST_DEVICE
1676
+ AccessType *get() const {
1677
+ return reinterpret_cast<AccessType *>(iterator_.get());
1678
+ }
1679
+
1680
+ /// Advances to the next tile in memory.
1681
+ ///
1682
+ /// The first time this method is called, predicates are updated, and the
1683
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1684
+ /// Subsequent calls are lightweight and must only update the internal
1685
+ /// pointer.
1686
+ CUTLASS_HOST_DEVICE
1687
+ PredicatedTileAccessIterator &operator++() {
1688
+ ++iterator_;
1689
+ return *this;
1690
+ }
1691
+
1692
+ /// Advances to the next tile in memory.
1693
+ ///
1694
+ /// The first time this method is called, predicates are updated, and the
1695
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1696
+ /// Subsequent calls are lightweight and must only update the internal
1697
+ /// pointer.
1698
+ CUTLASS_HOST_DEVICE
1699
+ PredicatedTileAccessIterator operator++(int) {
1700
+ PredicatedTileAccessIterator self(*this);
1701
+ operator++();
1702
+ return self;
1703
+ }
1704
+
1705
+ /// Clears the predicate set efficiently
1706
+ CUTLASS_HOST_DEVICE
1707
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1708
+
1709
+ /// Clears the predicate set efficiently
1710
+ CUTLASS_HOST_DEVICE
1711
+ void enable_mask() { iterator_.enable_mask(); }
1712
+
1713
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1714
+ CUTLASS_HOST_DEVICE
1715
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1716
+
1717
+ /// Gets the mask
1718
+ CUTLASS_HOST_DEVICE
1719
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1720
+
1721
+ /// Returns whether access is valid or not
1722
+ CUTLASS_HOST_DEVICE
1723
+ bool valid() {
1724
+ return iterator_.valid();
1725
+ }
1726
+ };
1727
+
1728
+ ////////////////////////////////////////////////////////////////////////////////
1729
+
1730
+ /// Specialization of PredicatedTileAccessIterator for column-major interleaved data.
1731
+ /// It is mapped to the congruous layout.
1732
+ ///
1733
+ /// Satisfies: ForwardTileIteratorConcept |
1734
+ /// ReadableContiguousTileIteratorConcept |
1735
+ /// WriteableContiguousTileIteratorConcept |
1736
+ /// MaskedTileIteratorConcept
1737
+ ///
1738
+
1739
+ template <typename Shape_, typename Element_, int AdvanceRank,
1740
+ typename ThreadMap_, typename AccessType_, int InterleavedK>
1741
+ class PredicatedTileAccessIterator<Shape_, Element_,
1742
+ layout::ColumnMajorInterleaved<InterleavedK>,
1743
+ AdvanceRank, ThreadMap_, AccessType_, false,
1744
+ layout::NoPermute> {
1745
+ public:
1746
+ static_assert(
1747
+ AdvanceRank == 0 || AdvanceRank == 1,
1748
+ "Specialization for pitch-linear iterator may along advance along the "
1749
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1750
+
1751
+ using Shape = Shape_;
1752
+ using Element = Element_;
1753
+ static int const kInterleavedK = InterleavedK;
1754
+ using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
1755
+ static int const kAdvanceRank = AdvanceRank;
1756
+ using ThreadMap = ThreadMap_;
1757
+ using AccessType = AccessType_;
1758
+
1759
+ using Index = typename Layout::Index;
1760
+ using LongIndex = typename Layout::LongIndex;
1761
+
1762
+ using TensorRef = TensorRef<Element, Layout>;
1763
+ using TensorView = TensorView<Element, Layout>;
1764
+ using TensorCoord = typename Layout::TensorCoord;
1765
+
1766
+ using Pointer = Element *;
1767
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1768
+
1769
+ using UnderlyingIterator = PredicatedTileAccessIterator<
1770
+ layout::PitchLinearShape<Shape::kRow * kInterleavedK,
1771
+ Shape::kColumn / kInterleavedK>,
1772
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
1773
+ AccessType>;
1774
+
1775
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1776
+
1777
+ /// Predicate vector stores mask to guard accesses
1778
+ using Mask = typename UnderlyingIterator::Mask;
1779
+
1780
+ /// Parameters object is precomputed state and is host-constructible
1781
+ class Params {
1782
+ private:
1783
+ friend PredicatedTileAccessIterator;
1784
+
1785
+ /// Parameters object
1786
+ typename UnderlyingIterator::Params params_;
1787
+
1788
+ public:
1789
+
1790
+ /// Default constructor
1791
+ Params() = default;
1792
+
1793
+ /// Construct the Params object given a pitch-linear tensor's layout
1794
+ CUTLASS_HOST_DEVICE
1795
+ Params(Layout const &layout)
1796
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1797
+
1798
+ CUTLASS_HOST_DEVICE
1799
+ Params(typename UnderlyingIterator::Params::Base const &base)
1800
+ : params_(base) {}
1801
+ };
1802
+
1803
+ private:
1804
+ //
1805
+ // Data members
1806
+ //
1807
+
1808
+ /// Underlying pitch-linear tile iterator
1809
+ UnderlyingIterator iterator_;
1810
+
1811
+ public:
1812
+
1813
+ /// Default constructor
1814
+ PredicatedTileAccessIterator() = default;
1815
+
1816
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1817
+ /// and thread ID
1818
+ CUTLASS_HOST_DEVICE
1819
+ PredicatedTileAccessIterator(
1820
+ /// Precomputed parameters object
1821
+ Params const &params,
1822
+ /// Pointer to start of tensor
1823
+ Pointer pointer,
1824
+ /// Extent of tensor
1825
+ TensorCoord extent,
1826
+ /// ID of each participating thread
1827
+ int thread_id,
1828
+ /// Initial offset of threadblock
1829
+ TensorCoord const &threadblock_offset,
1830
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1831
+ )
1832
+ : iterator_(params.params_, pointer,
1833
+ layout::PitchLinearCoord(extent.row() * kInterleavedK,
1834
+ extent.column() / kInterleavedK),
1835
+ thread_id,
1836
+ layout::PitchLinearCoord(
1837
+ threadblock_offset.row() * kInterleavedK,
1838
+ threadblock_offset.column() / kInterleavedK)) {}
1839
+
1840
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
1841
+ CUTLASS_HOST_DEVICE
1842
+ PredicatedTileAccessIterator(
1843
+ Params const &params, ///< Precomputed parameters object
1844
+ Pointer pointer, ///< Pointer to start of tensor
1845
+ TensorCoord extent, ///< Extent of tensor
1846
+ int thread_id ///< ID of each participating thread
1847
+ )
1848
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1849
+ make_Coord(0, 0)) {}
1850
+
1851
+ /// Overrides the internal iteration index
1852
+ CUTLASS_HOST_DEVICE
1853
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1854
+
1855
+ /// Adds a pointer offset in units of Element
1856
+ CUTLASS_HOST_DEVICE
1857
+ void add_pointer_offset(LongIndex pointer_offset) {
1858
+ iterator_.add_pointer_offset(pointer_offset);
1859
+ }
1860
+
1861
+ /// Advances an iterator along logical dimensions of matrix in units of whole
1862
+ /// tiles
1863
+ CUTLASS_HOST_DEVICE
1864
+ void add_tile_offset(TensorCoord const &tile_offset) {
1865
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
1866
+ }
1867
+
1868
+ /// Returns a pointer
1869
+ CUTLASS_HOST_DEVICE
1870
+ AccessType *get() const {
1871
+ return reinterpret_cast<AccessType *>(iterator_.get());
1872
+ }
1873
+
1874
+ /// Advances to the next tile in memory.
1875
+ ///
1876
+ /// The first time this method is called, predicates are updated, and the
1877
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1878
+ /// Subsequent calls are lightweight and must only update the internal
1879
+ /// pointer.
1880
+ CUTLASS_HOST_DEVICE
1881
+ PredicatedTileAccessIterator &operator++() {
1882
+ ++iterator_;
1883
+ return *this;
1884
+ }
1885
+
1886
+ /// Advances to the next tile in memory.
1887
+ ///
1888
+ /// The first time this method is called, predicates are updated, and the
1889
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1890
+ /// Subsequent calls are lightweight and must only update the internal
1891
+ /// pointer.
1892
+ CUTLASS_HOST_DEVICE
1893
+ PredicatedTileAccessIterator operator++(int) {
1894
+ PredicatedTileAccessIterator self(*this);
1895
+ operator++();
1896
+ return self;
1897
+ }
1898
+
1899
+ /// Clears the predicate set efficiently
1900
+ CUTLASS_HOST_DEVICE
1901
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1902
+
1903
+ /// Clears the predicate set efficiently
1904
+ CUTLASS_HOST_DEVICE
1905
+ void enable_mask() { iterator_.enable_mask(); }
1906
+
1907
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1908
+ CUTLASS_HOST_DEVICE
1909
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1910
+
1911
+ /// Gets the mask
1912
+ CUTLASS_HOST_DEVICE
1913
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1914
+
1915
+ /// Returns whether access is valid or not
1916
+ CUTLASS_HOST_DEVICE
1917
+ bool valid() { return iterator_.valid(); }
1918
+ };
1919
+
1920
+ ////////////////////////////////////////////////////////////////////////////////
1921
+
1922
+ /// Specialization of PredicatedTileAccessIterator for row-major interleaved data.
1923
+ // It is mapped to the congruous layout.
1924
+ ///
1925
+ /// Satisfies: ForwardTileIteratorConcept |
1926
+ /// ReadableContiguousTileIteratorConcept |
1927
+ /// WriteableContiguousTileIteratorConcept |
1928
+ /// MaskedTileIteratorConcept
1929
+ ///
1930
+ template <typename Shape_, typename Element_, int AdvanceRank,
1931
+ typename ThreadMap_, typename AccessType_, int InterleavedK>
1932
+ class PredicatedTileAccessIterator<Shape_, Element_,
1933
+ layout::RowMajorInterleaved<InterleavedK>,
1934
+ AdvanceRank, ThreadMap_, AccessType_, false,
1935
+ layout::NoPermute> {
1936
+ public:
1937
+ static_assert(
1938
+ AdvanceRank == 0 || AdvanceRank == 1,
1939
+ "Specialization for pitch-linear iterator may along advance along the "
1940
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1941
+
1942
+ using Shape = Shape_;
1943
+ using Element = Element_;
1944
+ static int const kInterleavedK = InterleavedK;
1945
+ using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1946
+ static int const kAdvanceRank = AdvanceRank;
1947
+ using ThreadMap = ThreadMap_;
1948
+ using AccessType = AccessType_;
1949
+
1950
+ using Index = typename Layout::Index;
1951
+ using LongIndex = typename Layout::LongIndex;
1952
+
1953
+ using TensorRef = TensorRef<Element, Layout>;
1954
+ using TensorView = TensorView<Element, Layout>;
1955
+ using TensorCoord = typename Layout::TensorCoord;
1956
+
1957
+ using Pointer = Element *;
1958
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1959
+
1960
+ using UnderlyingIterator = PredicatedTileAccessIterator<
1961
+ layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1962
+ Shape::kRow / kInterleavedK>,
1963
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
1964
+ AccessType>;
1965
+
1966
+
1967
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1968
+
1969
+ /// Predicate vector stores mask to guard accesses
1970
+ using Mask = typename UnderlyingIterator::Mask;
1971
+
1972
+ /// Parameters object is precomputed state and is host-constructible
1973
+ class Params {
1974
+ private:
1975
+ friend PredicatedTileAccessIterator;
1976
+
1977
+ /// Parameters object
1978
+ typename UnderlyingIterator::Params params_;
1979
+
1980
+ public:
1981
+
1982
+ /// Default constructor
1983
+ Params() = default;
1984
+
1985
+ /// Construct the Params object given a pitch-linear tensor's layout
1986
+ CUTLASS_HOST_DEVICE
1987
+ Params(Layout const &layout)
1988
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1989
+
1990
+ CUTLASS_HOST_DEVICE
1991
+ Params(typename UnderlyingIterator::Params::Base const &base)
1992
+ : params_(base) {}
1993
+ };
1994
+
1995
+ private:
1996
+ //
1997
+ // Data members
1998
+ //
1999
+
2000
+ /// Underlying pitch-linear tile iterator
2001
+ UnderlyingIterator iterator_;
2002
+
2003
+ public:
2004
+
2005
+ /// Default constructor
2006
+ PredicatedTileAccessIterator() = default;
2007
+
2008
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
2009
+ /// and thread ID
2010
+ CUTLASS_HOST_DEVICE
2011
+ PredicatedTileAccessIterator(
2012
+ /// Precomputed parameters object
2013
+ Params const &params,
2014
+ /// Pointer to start of tensor
2015
+ Pointer pointer,
2016
+ /// Extent of tensor
2017
+ TensorCoord extent,
2018
+ /// ID of each participating thread
2019
+ int thread_id,
2020
+ /// Initial offset of threadblock
2021
+ TensorCoord const &threadblock_offset,
2022
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
2023
+ )
2024
+ : iterator_(params.params_, pointer,
2025
+ layout::PitchLinearCoord(extent.column() * kInterleavedK,
2026
+ extent.row() / kInterleavedK),
2027
+ thread_id,
2028
+ layout::PitchLinearCoord(
2029
+ threadblock_offset.column() * kInterleavedK,
2030
+ threadblock_offset.row() / kInterleavedK)) {}
2031
+
2032
+ /// Construct a PredicatedTileAccessIterator with zero threadblock offset
2033
+ CUTLASS_HOST_DEVICE
2034
+ PredicatedTileAccessIterator(
2035
+ Params const &params, ///< Precomputed parameters object
2036
+ Pointer pointer, ///< Pointer to start of tensor
2037
+ TensorCoord extent, ///< Extent of tensor
2038
+ int thread_id ///< ID of each participating thread
2039
+ )
2040
+ : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
2041
+ make_Coord(0, 0)) {}
2042
+
2043
+ /// Overrides the internal iteration index
2044
+ CUTLASS_HOST_DEVICE
2045
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
2046
+
2047
+ /// Adds a pointer offset in units of Element
2048
+ CUTLASS_HOST_DEVICE
2049
+ void add_pointer_offset(LongIndex pointer_offset) {
2050
+ iterator_.add_pointer_offset(pointer_offset);
2051
+ }
2052
+
2053
+ /// Advances an iterator along logical dimensions of matrix in units of whole
2054
+ /// tiles
2055
+ CUTLASS_HOST_DEVICE
2056
+ void add_tile_offset(TensorCoord const &tile_offset) {
2057
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
2058
+ }
2059
+
2060
+ /// Returns a pointer
2061
+ CUTLASS_HOST_DEVICE
2062
+ AccessType *get() const {
2063
+ return reinterpret_cast<AccessType *>(iterator_.get());
2064
+ }
2065
+
2066
+ /// Advances to the next tile in memory.
2067
+ ///
2068
+ /// The first time this method is called, predicates are updated, and the
2069
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
2070
+ /// Subsequent calls are lightweight and must only update the internal
2071
+ /// pointer.
2072
+ CUTLASS_HOST_DEVICE
2073
+ PredicatedTileAccessIterator &operator++() {
2074
+ ++iterator_;
2075
+ return *this;
2076
+ }
2077
+
2078
+ /// Advances to the next tile in memory.
2079
+ ///
2080
+ /// The first time this method is called, predicates are updated, and the
2081
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
2082
+ /// Subsequent calls are lightweight and must only update the internal
2083
+ /// pointer.
2084
+ CUTLASS_HOST_DEVICE
2085
+ PredicatedTileAccessIterator operator++(int) {
2086
+ PredicatedTileAccessIterator self(*this);
2087
+ operator++();
2088
+ return self;
2089
+ }
2090
+
2091
+ /// Clears the predicate set efficiently
2092
+ CUTLASS_HOST_DEVICE
2093
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
2094
+
2095
+ /// Clears the predicate set efficiently
2096
+ CUTLASS_HOST_DEVICE
2097
+ void enable_mask() { iterator_.enable_mask(); }
2098
+
2099
+ /// Sets the predicate mask, overriding value stored in predicate iterator
2100
+ CUTLASS_HOST_DEVICE
2101
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
2102
+
2103
+ /// Gets the mask
2104
+ CUTLASS_HOST_DEVICE
2105
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
2106
+
2107
+ /// Returns whether access is valid or not
2108
+ CUTLASS_HOST_DEVICE
2109
+ bool valid() { return iterator_.valid(); }
2110
+ };
2111
+
2112
+ ////////////////////////////////////////////////////////////////////////////////
2113
+
2114
+ } // namespace threadblock
2115
+ } // namespace transform
2116
+ } // namespace cutlass
2117
+
2118
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates calculating the address and predicates to the load of tiles
33
+ from pitch-linear rank=2 tensors.
34
+
35
+ This iterator uses masks to guard out-of-bounds accesses and visits the last
36
+ "residue" tile first, with the objective of minimizing predicate mask updates
37
+ during steady-state operation.
38
+
39
+ A precomputed "Params" object minimizes the amount of state that must be
40
+ stored in registers, and integer addition is used to advance the pointer
41
+ through memory.
42
+ */
43
+
44
+ #pragma once
45
+
46
+ #include "cutlass/array.h"
47
+ #include "cutlass/coord.h"
48
+ #include "cutlass/cutlass.h"
49
+ #include "cutlass/layout/matrix.h"
50
+ #include "cutlass/layout/pitch_linear.h"
51
+ #include "cutlass/matrix_shape.h"
52
+ #include "cutlass/predicate_vector.h"
53
+ #include "cutlass/tensor_ref.h"
54
+ #include "cutlass/tensor_view.h"
55
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
56
+
57
+ ////////////////////////////////////////////////////////////////////////////////
58
+
59
+ ////////////////////////////////////////////////////////////////////////////////
60
+
61
+ namespace cutlass {
62
+ namespace transform {
63
+ namespace threadblock {
64
+
65
+ ////////////////////////////////////////////////////////////////////////////////
66
+
67
+ /// PredicatedTileAccessIterator2dThreadTile
68
+ ///
69
+ template <typename Shape, typename Element, typename Layout, int AdvanceRank,
70
+ typename ThreadMap, typename AccessType>
71
+ class PredicatedTileAccessIterator2dThreadTile;
72
+
73
+ ////////////////////////////////////////////////////////////////////////////////
74
+
75
+ /// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
76
+ ///
77
+ template <typename Shape_, typename Element_, int AdvanceRank,
78
+ typename ThreadMap_, typename AccessType_>
79
+ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
80
+ AdvanceRank, ThreadMap_, AccessType_> {
81
+ public:
82
+ static_assert(
83
+ AdvanceRank == 0 || AdvanceRank == 1,
84
+ "Specialization for pitch-linear iterator may along advance along the "
85
+ "contiguous(rank=0) or strided(rank=1) dimension.");
86
+
87
+ using Shape = Shape_;
88
+ using Element = Element_;
89
+ using Layout = layout::PitchLinear;
90
+ static int const kAdvanceRank = AdvanceRank;
91
+ using ThreadMap = ThreadMap_;
92
+ using AccessType = AccessType_;
93
+
94
+ using Index = typename Layout::Index;
95
+ using LongIndex = typename Layout::LongIndex;
96
+ using StrideIndex = typename Layout::Stride::Index;
97
+
98
+ using TensorRef = TensorRef<Element, Layout>;
99
+ using TensorView = TensorView<Element, Layout>;
100
+ using TensorCoord = typename Layout::TensorCoord;
101
+
102
+ using Pointer = Element *;
103
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
104
+
105
+ static int const kPredicatesPerByte = 4;
106
+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
107
+
108
+ /// Number of 32b words containing predicates
109
+ static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte;
110
+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
111
+
112
+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
113
+
114
+ static_assert(kPredicateWordCount <= 4, "Too many predicates.");
115
+
116
+ /// Predicate vector stores mask to guard accesses
117
+ using Mask = Array<uint32_t, kPredicateWordCount>;
118
+
119
+ /// Uses a non-template class
120
+ struct Params : PredicatedTileAccessIteratorParams {
121
+
122
+ public:
123
+ friend PredicatedTileAccessIterator2dThreadTile;
124
+
125
+ using Base = PredicatedTileAccessIteratorParams;
126
+
127
+ // Default ctor
128
+ CUTLASS_HOST_DEVICE
129
+ Params() { }
130
+
131
+ /// Construct the Params object given a pitch-linear tensor's layout
132
+ CUTLASS_HOST_DEVICE
133
+ Params(Layout const &layout) :
134
+ Base(layout.stride(0),
135
+ MakePredicatedTileAccessIteratorDesc<Shape, Element, Layout, kAdvanceRank, ThreadMap>()()
136
+ ) { }
137
+
138
+ CUTLASS_HOST_DEVICE
139
+ Params(Base const &base) :
140
+ Base(base) { }
141
+ };
142
+
143
+
144
+ private:
145
+ /// Internal pointer type permits fast address arithmetic
146
+ using BytePointer = char *;
147
+
148
+ private:
149
+ //
150
+ // Data members
151
+ //
152
+
153
+ /// Parameters object with precomputed internal state
154
+ Params const &params_;
155
+
156
+ /// Internal pointer to first access of tile
157
+ BytePointer pointer_;
158
+
159
+ /// Guard predicates
160
+ uint32_t predicates_[kPredicateWordCount];
161
+
162
+ /// Size of tensor
163
+ TensorCoord extent_;
164
+
165
+ /// Initial offset for each thread
166
+ TensorCoord thread_offset_;
167
+
168
+ /// Index of residue tile
169
+ int residue_tile_idx_;
170
+
171
+ /// Used for out-of-order visitation
172
+ bool is_residue_tile_;
173
+
174
+ /// Iteration in the contiguous dimension
175
+ int iteration_contiguous_;
176
+
177
+ /// Iteration in the strided dimension
178
+ int iteration_strided_;
179
+
180
+ /// Tracks iterations within the thread loop
181
+ int iteration_thread_;
182
+
183
+ private:
184
+ /// Computes predicates based on internally tracked per-thread offset.
185
+ CUTLASS_HOST_DEVICE
186
+ void compute_predicates_(
187
+ /// optionally, simplify predicate calculation during 'steady state' phase
188
+ bool is_steady_state = false) {
189
+
190
+ CUTLASS_PRAGMA_UNROLL
191
+ for (int i = 0; i < kPredicateWordCount; ++i) {
192
+ predicates_[i] = 0u;
193
+ }
194
+
195
+ CUTLASS_PRAGMA_UNROLL
196
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
197
+ CUTLASS_PRAGMA_UNROLL
198
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
199
+ CUTLASS_PRAGMA_UNROLL
200
+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) {
201
+
202
+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous,
203
+ ts + s * ThreadMap::Delta::kStrided);
204
+
205
+ TensorCoord coord = thread_offset_ + iteration_coord;
206
+
207
+ bool guard;
208
+
209
+ if (is_steady_state) {
210
+ if (kAdvanceRank == 0) {
211
+ guard = (coord.strided() < extent_.strided());
212
+ } else {
213
+ guard = (coord.contiguous() < extent_.contiguous());
214
+ }
215
+ } else {
216
+ guard = (coord.strided() < extent_.strided() &&
217
+ coord.contiguous() < extent_.contiguous());
218
+ }
219
+
220
+ int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
221
+ int word_idx = pred_idx / kPredicatesPerWord;
222
+ int residual = pred_idx % kPredicatesPerWord;
223
+ int byte_idx = residual / kPredicatesPerByte;
224
+ int bit_idx = residual % kPredicatesPerByte;
225
+
226
+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
227
+
228
+ }
229
+ }
230
+ }
231
+
232
+ }
233
+
234
+ public:
235
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
236
+ /// and thread ID
237
+ CUTLASS_HOST_DEVICE
238
+ PredicatedTileAccessIterator2dThreadTile(
239
+ /// Precomputed parameters object
240
+ Params const &params,
241
+ /// Pointer to start of tensor
242
+ Pointer pointer,
243
+ /// Extent of tensor
244
+ TensorCoord extent,
245
+ /// ID of each participating thread
246
+ int thread_id,
247
+ /// Initial offset of threadblock
248
+ TensorCoord const &threadblock_offset)
249
+ : params_(params),
250
+ pointer_(reinterpret_cast<BytePointer>(
251
+ const_cast<NonConstPointer>(pointer))),
252
+ extent_(extent),
253
+ is_residue_tile_(true) {
254
+
255
+
256
+ TensorCoord residue_offset;
257
+ if (kAdvanceRank) {
258
+ residue_tile_idx_ =
259
+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
260
+ Shape::kStrided;
261
+ residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided);
262
+ } else {
263
+ residue_tile_idx_ =
264
+ (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
265
+ Shape::kContiguous;
266
+ residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0);
267
+ }
268
+
269
+ // Per-thread offset in logical coordinates of tensor
270
+ thread_offset_ = threadblock_offset + residue_offset +
271
+ ThreadMap::initial_offset(thread_id);
272
+
273
+ // update internal pointers
274
+ Layout layout(params_.stride_);
275
+ add_pointer_offset(layout(thread_offset_));
276
+
277
+ compute_predicates_(false);
278
+
279
+ set_iteration_index(0);
280
+ }
281
+
282
+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
283
+ CUTLASS_HOST_DEVICE
284
+ PredicatedTileAccessIterator2dThreadTile(
285
+ /// Precomputed parameters object
286
+ Params const &params,
287
+ /// Pointer to start of tensor
288
+ Pointer pointer,
289
+ /// Extent of tensor
290
+ TensorCoord extent,
291
+ ///< ID of each participating thread
292
+ int thread_id)
293
+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
294
+ make_Coord(0, 0)) {}
295
+
296
+ /// Overrides the internal iteration index
297
+ CUTLASS_HOST_DEVICE
298
+ void set_iteration_index(int index) {
299
+
300
+ int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
301
+ iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
302
+
303
+ iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided;
304
+ iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided;
305
+
306
+ }
307
+
308
+ /// Adds a pointer offset in units of Element
309
+ CUTLASS_HOST_DEVICE
310
+ void add_pointer_offset(LongIndex pointer_offset) {
311
+ pointer_ += int(sizeof(Element)) * pointer_offset;
312
+ }
313
+
314
+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles
315
+ CUTLASS_DEVICE
316
+ void add_tile_offset(
317
+ TensorCoord const &tile_offset) {
318
+ if (is_residue_tile_) {
319
+ TensorCoord residue_offset;
320
+ if (kAdvanceRank) {
321
+ residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided);
322
+ } else {
323
+ residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0);
324
+ }
325
+
326
+ thread_offset_ -= residue_offset;
327
+
328
+ Layout layout(params_.stride_);
329
+ add_pointer_offset(-layout(residue_offset));
330
+
331
+ compute_predicates_(true);
332
+
333
+ if (kAdvanceRank) {
334
+ pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
335
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
336
+ } else {
337
+ pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
338
+ pointer_ += Shape::kStrided * tile_offset.strided();
339
+ }
340
+ } else {
341
+ if (kAdvanceRank) {
342
+ pointer_ += params_.inc_advance_ * tile_offset.strided();
343
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
344
+ } else {
345
+ pointer_ += params_.inc_advance_ * tile_offset.contiguous();
346
+ pointer_ += Shape::kStrided * tile_offset.strided();
347
+ }
348
+ }
349
+ is_residue_tile_ = false;
350
+ }
351
+
352
+ CUTLASS_HOST_DEVICE
353
+ AccessType *get() const {
354
+
355
+ AccessType *ret_val = reinterpret_cast<AccessType *>(
356
+ pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element)));
357
+
358
+ return ret_val;
359
+ }
360
+
361
+ /// Increment and return an instance to self.
362
+ CUTLASS_HOST_DEVICE
363
+ PredicatedTileAccessIterator2dThreadTile &operator++() {
364
+
365
+ iteration_thread_++;
366
+
367
+ if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided)
368
+ return *this;
369
+
370
+ iteration_thread_ = 0;
371
+
372
+ ++iteration_contiguous_;
373
+
374
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
375
+ return *this;
376
+
377
+ // Enter here only if (iteration_contiguous_ ==
378
+ // ThreadMap::Iteration::kContiguous)
379
+ iteration_contiguous_ = 0;
380
+ ++iteration_strided_;
381
+
382
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
383
+ pointer_ += params_.inc_strided_;
384
+ return *this;
385
+ }
386
+
387
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
388
+ // which means we enter the next tile.
389
+ iteration_strided_ = 0;
390
+
391
+ // advance to next tile
392
+ pointer_ += params_.inc_next_;
393
+
394
+ // now return to start tile - if the iterator is subsequently advanced, this
395
+ // subtraction as well as the subsequent integer addition are both elided by
396
+ // the compiler.
397
+ pointer_ -= params_.inc_advance_;
398
+
399
+ return *this;
400
+ }
401
+
402
+ /// Increment and return an instance to self.
403
+ CUTLASS_HOST_DEVICE
404
+ PredicatedTileAccessIterator2dThreadTile operator++(int) {
405
+ PredicatedTileAccessIterator2dThreadTile self(*this);
406
+ operator++();
407
+ return self;
408
+ }
409
+
410
+ /// Clears the predicate set efficiently
411
+ CUTLASS_HOST_DEVICE
412
+ void clear_mask(bool enable = true) {
413
+ CUTLASS_PRAGMA_UNROLL
414
+ for (int i = 0; i < kPredicateWordCount; ++i) {
415
+ predicates_[i] = enable ? 0u : predicates_[i];
416
+ }
417
+
418
+ }
419
+
420
+ /// Clears the predicate set efficiently
421
+ CUTLASS_HOST_DEVICE
422
+ void enable_mask() {
423
+ CUTLASS_PRAGMA_UNROLL
424
+ for (int i = 0; i < kPredicateWordCount; ++i) {
425
+ predicates_[i] = 0xffffffff;
426
+ }
427
+ }
428
+
429
+ /// Sets the predicate mask, overriding value stored in predicate iterator
430
+ CUTLASS_HOST_DEVICE
431
+ void set_mask(Mask const &mask) {
432
+ CUTLASS_PRAGMA_UNROLL
433
+ for (int i = 0; i < kPredicateWordCount; ++i) {
434
+ predicates_[i] = mask[i];
435
+ }
436
+
437
+ }
438
+
439
+ /// Gets the mask
440
+ CUTLASS_HOST_DEVICE
441
+ void get_mask(Mask &mask) {
442
+ CUTLASS_PRAGMA_UNROLL
443
+ for (int i = 0; i < kPredicateWordCount; ++i) {
444
+ mask[i] = predicates_[i];
445
+ }
446
+ }
447
+
448
+ /// Returns whether access is valid or not
449
+ CUTLASS_HOST_DEVICE
450
+ bool valid() {
451
+
452
+ int pred_idx =
453
+ iteration_thread_ +
454
+ iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided +
455
+ iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
456
+
457
+ int word_idx = pred_idx / kPredicatesPerWord;
458
+ int residual = pred_idx % kPredicatesPerWord;
459
+ int byte_idx = residual / kPredicatesPerByte;
460
+ int bit_idx = residual % kPredicatesPerByte;
461
+
462
+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
463
+
464
+ return pred;
465
+ }
466
+ };
467
+
468
+ ////////////////////////////////////////////////////////////////////////////////
469
+
470
+ /// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
471
+ ///
472
+ /// Satisfies: ForwardTileIteratorConcept |
473
+ /// ReadableContiguousTileIteratorConcept |
474
+ /// WriteableContiguousTileIteratorConcept |
475
+ /// MaskedTileIteratorConcept
476
+ ///
477
+ template <typename Shape_, typename Element_, int AdvanceRank,
478
+ typename ThreadMap_, typename AccessType_>
479
+ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor,
480
+ AdvanceRank, ThreadMap_, AccessType_> {
481
+ public:
482
+ static_assert(
483
+ AdvanceRank == 0 || AdvanceRank == 1,
484
+ "Specialization for pitch-linear iterator may along advance along the "
485
+ "contiguous(rank=0) or strided(rank=1) dimension.");
486
+
487
+ using Shape = Shape_;
488
+ using Element = Element_;
489
+ using Layout = layout::ColumnMajor;
490
+ static int const kAdvanceRank = AdvanceRank;
491
+ using ThreadMap = ThreadMap_;
492
+ using AccessType = AccessType_;
493
+
494
+ using Index = typename Layout::Index;
495
+ using LongIndex = typename Layout::LongIndex;
496
+
497
+ using TensorRef = TensorRef<Element, Layout>;
498
+ using TensorView = TensorView<Element, Layout>;
499
+ using TensorCoord = typename Layout::TensorCoord;
500
+
501
+ using Pointer = Element *;
502
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
503
+
504
+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
505
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
506
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
507
+
508
+ /// Predicate vector stores mask to guard accesses
509
+ using Mask = typename UnderlyingIterator::Mask;
510
+
511
+ /// Parameters object is precomputed state and is host-constructible
512
+ class Params {
513
+ private:
514
+ friend PredicatedTileAccessIterator2dThreadTile;
515
+
516
+ /// Parameters object
517
+ typename UnderlyingIterator::Params params_;
518
+
519
+ public:
520
+
521
+ /// Default ctor
522
+ CUTLASS_HOST_DEVICE
523
+ Params() { }
524
+
525
+ /// Construct the Params object given a pitch-linear tensor's layout
526
+ CUTLASS_HOST_DEVICE
527
+ Params(Layout const &layout)
528
+ : params_(layout::PitchLinear(layout.stride(0))){}
529
+
530
+ /// Construct the Params object given a pitch-linear tensor's layout
531
+ CUTLASS_HOST_DEVICE
532
+ Params(typename UnderlyingIterator::Params::Base const &base)
533
+ : params_(base) {}
534
+ };
535
+
536
+ private:
537
+ //
538
+ // Data members
539
+ //
540
+
541
+ /// Underlying pitch-linear tile iterator
542
+ UnderlyingIterator iterator_;
543
+
544
+ public:
545
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
546
+ /// and thread ID
547
+ CUTLASS_HOST_DEVICE
548
+ PredicatedTileAccessIterator2dThreadTile(
549
+ ///< Precomputed parameters object
550
+ Params const &params,
551
+ ///< Pointer to start of tensor
552
+ Pointer pointer,
553
+ ///< Extent of tensor
554
+ TensorCoord extent,
555
+ ///< ID of each participating thread
556
+ int thread_id,
557
+ ///< Initial offset of threadblock
558
+ TensorCoord const &threadblock_offset)
559
+ : iterator_(params.params_, pointer,
560
+ layout::PitchLinearCoord(extent.row(), extent.column()),
561
+ thread_id,
562
+ layout::PitchLinearCoord(threadblock_offset.row(),
563
+ threadblock_offset.column())) {}
564
+
565
+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
566
+ CUTLASS_HOST_DEVICE
567
+ PredicatedTileAccessIterator2dThreadTile(
568
+ Params const &params, ///< Precomputed parameters object
569
+ Pointer pointer, ///< Pointer to start of tensor
570
+ TensorCoord extent, ///< Extent of tensor
571
+ int thread_id ///< ID of each participating thread
572
+ )
573
+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
574
+ make_Coord(0, 0)) {}
575
+
576
+ /// Overrides the internal iteration index
577
+ CUTLASS_HOST_DEVICE
578
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
579
+
580
+ /// Adds a pointer offset in units of Element
581
+ CUTLASS_HOST_DEVICE
582
+ void add_pointer_offset(LongIndex pointer_offset) {
583
+ iterator_.add_pointer_offset(pointer_offset);
584
+ }
585
+
586
+ /// Advances an iterator along logical dimensions of matrix in units of whole
587
+ /// tiles
588
+ CUTLASS_HOST_DEVICE
589
+ void add_tile_offset(TensorCoord const &tile_offset) {
590
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
591
+ }
592
+
593
+ /// Returns a pointer
594
+ CUTLASS_HOST_DEVICE
595
+ AccessType *get() const {
596
+ return reinterpret_cast<AccessType *>(iterator_.get());
597
+ }
598
+
599
+ /// Advances to the next tile in memory.
600
+ ///
601
+ /// The first time this method is called, predicates are updated, and the
602
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
603
+ /// Subsequent calls are lightweight and must only update the internal
604
+ /// pointer.
605
+ CUTLASS_HOST_DEVICE
606
+ PredicatedTileAccessIterator2dThreadTile &operator++() {
607
+ ++iterator_;
608
+ return *this;
609
+ }
610
+
611
+ /// Advances to the next tile in memory.
612
+ ///
613
+ /// The first time this method is called, predicates are updated, and the
614
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
615
+ /// Subsequent calls are lightweight and must only update the internal
616
+ /// pointer.
617
+ CUTLASS_HOST_DEVICE
618
+ PredicatedTileAccessIterator2dThreadTile operator++(int) {
619
+ PredicatedTileAccessIterator2dThreadTile self(*this);
620
+ operator++();
621
+ return self;
622
+ }
623
+
624
+ /// Clears the predicate set efficiently
625
+ CUTLASS_HOST_DEVICE
626
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
627
+
628
+ /// Clears the predicate set efficiently
629
+ CUTLASS_HOST_DEVICE
630
+ void enable_mask() { iterator_.enable_mask(); }
631
+
632
+ /// Sets the predicate mask, overriding value stored in predicate iterator
633
+ CUTLASS_HOST_DEVICE
634
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
635
+
636
+ /// Gets the mask
637
+ CUTLASS_HOST_DEVICE
638
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
639
+
640
+ /// Returns whether access is valid or not
641
+ CUTLASS_HOST_DEVICE
642
+ bool valid() {
643
+ return iterator_.valid();
644
+ }
645
+ };
646
+
647
+ ////////////////////////////////////////////////////////////////////////////////
648
+
649
+ /// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
650
+ ///
651
+ /// Satisfies: ForwardTileIteratorConcept |
652
+ /// ReadableContiguousTileIteratorConcept |
653
+ /// WriteableContiguousTileIteratorConcept |
654
+ /// MaskedTileIteratorConcept
655
+ ///
656
+ template <typename Shape_, typename Element_, int AdvanceRank,
657
+ typename ThreadMap_, typename AccessType_>
658
+ class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajor,
659
+ AdvanceRank, ThreadMap_, AccessType_> {
660
+ public:
661
+ static_assert(
662
+ AdvanceRank == 0 || AdvanceRank == 1,
663
+ "Specialization for pitch-linear iterator may along advance along the "
664
+ "contiguous(rank=0) or strided(rank=1) dimension.");
665
+
666
+ using Shape = Shape_;
667
+ using Element = Element_;
668
+ using Layout = layout::RowMajor;
669
+ static int const kAdvanceRank = AdvanceRank;
670
+ using ThreadMap = ThreadMap_;
671
+ using AccessType = AccessType_;
672
+
673
+ using Index = typename Layout::Index;
674
+ using LongIndex = typename Layout::LongIndex;
675
+
676
+ using TensorRef = TensorRef<Element, Layout>;
677
+ using TensorView = TensorView<Element, Layout>;
678
+ using TensorCoord = typename Layout::TensorCoord;
679
+
680
+ using Pointer = Element *;
681
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
682
+
683
+ using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
684
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
685
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
686
+
687
+ /// Predicate vector stores mask to guard accesses
688
+ using Mask = typename UnderlyingIterator::Mask;
689
+
690
+ /// Parameters object is precomputed state and is host-constructible
691
+ class Params {
692
+ private:
693
+ friend PredicatedTileAccessIterator2dThreadTile;
694
+
695
+ /// Parameters object
696
+ typename UnderlyingIterator::Params params_;
697
+
698
+ public:
699
+
700
+ /// Default ctor
701
+ CUTLASS_HOST_DEVICE
702
+ Params() { }
703
+
704
+ /// Construct the Params object given a pitch-linear tensor's layout
705
+ CUTLASS_HOST_DEVICE
706
+ Params(Layout const &layout)
707
+ : params_(layout::PitchLinear(layout.stride(0))){}
708
+
709
+ /// Construct the Params object given a pitch-linear tensor's layout
710
+ CUTLASS_HOST_DEVICE
711
+ Params(typename UnderlyingIterator::Params::Base const &base)
712
+ : params_(base) {}
713
+ };
714
+
715
+ private:
716
+ //
717
+ // Data members
718
+ //
719
+
720
+ /// Underlying pitch-linear tile iterator
721
+ UnderlyingIterator iterator_;
722
+
723
+ public:
724
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
725
+ /// and thread ID
726
+ CUTLASS_HOST_DEVICE
727
+ PredicatedTileAccessIterator2dThreadTile(
728
+ ///< Precomputed parameters object
729
+ Params const &params,
730
+ ///< Pointer to start of tensor
731
+ Pointer pointer,
732
+ ///< Extent of tensor
733
+ TensorCoord extent,
734
+ ///< ID of each participating thread
735
+ int thread_id,
736
+ ///< Initial offset of threadblock
737
+ TensorCoord const &threadblock_offset)
738
+ : iterator_(params.params_, pointer,
739
+ layout::PitchLinearCoord(extent.column(), extent.row()),
740
+ thread_id,
741
+ layout::PitchLinearCoord(threadblock_offset.column(),
742
+ threadblock_offset.row())) {}
743
+
744
+ /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
745
+ CUTLASS_HOST_DEVICE
746
+ PredicatedTileAccessIterator2dThreadTile(
747
+ Params const &params, ///< Precomputed parameters object
748
+ Pointer pointer, ///< Pointer to start of tensor
749
+ TensorCoord extent, ///< Extent of tensor
750
+ int thread_id ///< ID of each participating thread
751
+ )
752
+ : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
753
+ make_Coord(0, 0)) {}
754
+
755
+ /// Overrides the internal iteration index
756
+ CUTLASS_HOST_DEVICE
757
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
758
+
759
+ /// Adds a pointer offset in units of Element
760
+ CUTLASS_HOST_DEVICE
761
+ void add_pointer_offset(LongIndex pointer_offset) {
762
+ iterator_.add_pointer_offset(pointer_offset);
763
+ }
764
+
765
+ /// Advances an iterator along logical dimensions of matrix in units of whole
766
+ /// tiles
767
+ CUTLASS_HOST_DEVICE
768
+ void add_tile_offset(TensorCoord const &tile_offset) {
769
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
770
+ }
771
+
772
+ /// Returns a pointer
773
+ CUTLASS_HOST_DEVICE
774
+ AccessType *get() const {
775
+ return reinterpret_cast<AccessType *>(iterator_.get());
776
+ }
777
+
778
+ /// Advances to the next tile in memory.
779
+ ///
780
+ /// The first time this method is called, predicates are updated, and the
781
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
782
+ /// Subsequent calls are lightweight and must only update the internal
783
+ /// pointer.
784
+ CUTLASS_HOST_DEVICE
785
+ PredicatedTileAccessIterator2dThreadTile &operator++() {
786
+ ++iterator_;
787
+ return *this;
788
+ }
789
+
790
+ /// Advances to the next tile in memory.
791
+ ///
792
+ /// The first time this method is called, predicates are updated, and the
793
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
794
+ /// Subsequent calls are lightweight and must only update the internal
795
+ /// pointer.
796
+ CUTLASS_HOST_DEVICE
797
+ PredicatedTileAccessIterator2dThreadTile operator++(int) {
798
+ PredicatedTileAccessIterator2dThreadTile self(*this);
799
+ operator++();
800
+ return self;
801
+ }
802
+
803
+ /// Clears the predicate set efficiently
804
+ CUTLASS_HOST_DEVICE
805
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
806
+
807
+ /// Clears the predicate set efficiently
808
+ CUTLASS_HOST_DEVICE
809
+ void enable_mask() { iterator_.enable_mask(); }
810
+
811
+ /// Sets the predicate mask, overriding value stored in predicate iterator
812
+ CUTLASS_HOST_DEVICE
813
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
814
+
815
+ /// Gets the mask
816
+ CUTLASS_HOST_DEVICE
817
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
818
+
819
+ /// Returns whether access is valid or not
820
+ CUTLASS_HOST_DEVICE
821
+ bool valid() {
822
+ return iterator_.valid();
823
+ }
824
+ };
825
+
826
+ ////////////////////////////////////////////////////////////////////////////////
827
+
828
+ ////////////////////////////////////////////////////////////////////////////////
829
+
830
+ } // namespace threadblock
831
+ } // namespace transform
832
+ } // namespace cutlass
833
+
834
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief
33
+ */
34
+
35
+ #pragma once
36
+
37
+ #include "cutlass/cutlass.h"
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/detail/helper_macros.hpp"
40
+ #include "cutlass/layout/matrix.h"
41
+ #include "cutlass/layout/pitch_linear.h"
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ namespace cutlass {
46
+ namespace transform {
47
+ namespace threadblock {
48
+
49
+ /////////////////////////////////////////////////////////////////////////////////////////////////
50
+
51
+ /// Predicated tile access iterator descriptor object containing template dependent state
52
+ struct PredicatedTileAccessIteratorDesc {
53
+
54
+ int element_size_bits = -1;
55
+ int advance_rank = -1;
56
+ layout::PitchLinearCoord threadblock_shape;
57
+ layout::PitchLinearCoord threadmap_iterations;
58
+ layout::PitchLinearCoord threadmap_delta;
59
+
60
+ //
61
+ // Methods
62
+ //
63
+
64
+ PredicatedTileAccessIteratorDesc() = default;
65
+
66
+ CUTLASS_HOST_DEVICE
67
+ PredicatedTileAccessIteratorDesc(
68
+ int element_size_bits_,
69
+ int advance_rank_,
70
+ layout::PitchLinearCoord threadblock_shape_,
71
+ layout::PitchLinearCoord threadmap_iterations_,
72
+ layout::PitchLinearCoord threadmap_delta_
73
+ ):
74
+ element_size_bits(element_size_bits_),
75
+ advance_rank(advance_rank_),
76
+ threadblock_shape(threadblock_shape_),
77
+ threadmap_iterations(threadmap_iterations_),
78
+ threadmap_delta(threadmap_delta_)
79
+ {
80
+ #if 0
81
+ printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n",
82
+ element_size_bits,
83
+ advance_rank,
84
+ threadblock_shape.contiguous(), threadblock_shape.strided(),
85
+ threadmap_iterations.contiguous(), threadmap_iterations.strided(),
86
+ threadmap_delta.contiguous(), threadmap_delta.strided());
87
+ #endif
88
+ }
89
+ };
90
+
91
+ /////////////////////////////////////////////////////////////////////////////////////////////////
92
+ /// Helper template to construct an PredicatedTileAccessIteratorDesc from a template
93
+ // dependent state
94
+ template <
95
+ typename Shape, typename Element, typename Layout,
96
+ int AdvanceRank, typename ThreadMap>
97
+ struct MakePredicatedTileAccessIteratorDesc;
98
+ /////////////////////////////////////////////////////////////////////////////////////////////////
99
+
100
+ /// Specialization of PredicatedTileAccessIterator for pitch-linear data.
101
+ template <
102
+ typename Shape, typename Element, int AdvanceRank,
103
+ typename ThreadMap>
104
+ struct MakePredicatedTileAccessIteratorDesc <
105
+ Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> {
106
+
107
+ CUTLASS_HOST_DEVICE
108
+ PredicatedTileAccessIteratorDesc operator()() {
109
+
110
+ return PredicatedTileAccessIteratorDesc(
111
+ sizeof_bits<Element>::value,
112
+ AdvanceRank,
113
+ {Shape::kContiguous, Shape::kStrided},
114
+ {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
115
+ {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
116
+ );
117
+ }
118
+
119
+ };
120
+ /////////////////////////////////////////////////////////////////////////////////////////////////
121
+
122
+ /// Specialization of PredicatedTileAccessIterator for column-major data.
123
+ template <
124
+ typename Shape, typename Element, int AdvanceRank,
125
+ typename ThreadMap>
126
+ struct MakePredicatedTileAccessIteratorDesc <
127
+ Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> {
128
+
129
+ static int const kAdvanceRank = AdvanceRank;
130
+
131
+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
132
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
133
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>;
134
+
135
+ CUTLASS_HOST_DEVICE
136
+ PredicatedTileAccessIteratorDesc operator()() {
137
+
138
+ return UnderlyingMakeOperator()();
139
+ }
140
+ };
141
+
142
+ /////////////////////////////////////////////////////////////////////////////////////////////////
143
+
144
+ /// Specialization of PredicatedTileAccessIterator for row-major data.
145
+ template <
146
+ typename Shape, typename Element, int AdvanceRank,
147
+ typename ThreadMap>
148
+ struct MakePredicatedTileAccessIteratorDesc <
149
+ Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> {
150
+
151
+ static int const kAdvanceRank = AdvanceRank;
152
+
153
+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
154
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
155
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>;
156
+
157
+ CUTLASS_HOST_DEVICE
158
+ PredicatedTileAccessIteratorDesc operator()() {
159
+
160
+ return UnderlyingMakeOperator()();
161
+ }
162
+ };
163
+
164
+ /////////////////////////////////////////////////////////////////////////////////////////////////
165
+
166
+ /// Specialization of PredicatedTileAccessIterator for column-major interleaved data.
167
+ template <
168
+ typename Shape, typename Element, int AdvanceRank,
169
+ typename ThreadMap, int InterleavedK>
170
+ struct MakePredicatedTileAccessIteratorDesc <
171
+ Shape, Element, layout::ColumnMajorInterleaved<InterleavedK>, AdvanceRank, ThreadMap> {
172
+
173
+ static int const kAdvanceRank = AdvanceRank;
174
+ static int const kInterleavedK = InterleavedK;
175
+
176
+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
177
+ layout::PitchLinearShape<Shape::kRow * kInterleavedK, Shape::kColumn / kInterleavedK>, Element,
178
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>;
179
+
180
+ CUTLASS_HOST_DEVICE
181
+ PredicatedTileAccessIteratorDesc operator()() {
182
+
183
+ return UnderlyingMakeOperator()();
184
+ }
185
+ };
186
+
187
+ /////////////////////////////////////////////////////////////////////////////////////////////////
188
+
189
+ /// Specialization of PredicatedTileAccessIterator for roww-major interleaved data.
190
+ template <
191
+ typename Shape, typename Element, int AdvanceRank,
192
+ typename ThreadMap, int InterleavedK>
193
+ struct MakePredicatedTileAccessIteratorDesc <
194
+ Shape, Element, layout::RowMajorInterleaved<InterleavedK>, AdvanceRank, ThreadMap> {
195
+
196
+ static int const kAdvanceRank = AdvanceRank;
197
+ static int const kInterleavedK = InterleavedK;
198
+
199
+ using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
200
+ layout::PitchLinearShape<Shape::kColumn * kInterleavedK, Shape::kRow / kInterleavedK>, Element,
201
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>;
202
+
203
+ CUTLASS_HOST_DEVICE
204
+ PredicatedTileAccessIteratorDesc operator()() {
205
+
206
+ return UnderlyingMakeOperator()();
207
+ }
208
+ };
209
+
210
+ /////////////////////////////////////////////////////////////////////////////////////////////////
211
+
212
+ //
213
+ // Parameters struct
214
+ //
215
+
216
+ struct PredicatedTileAccessIteratorParams {
217
+
218
+ using Index = int32_t;
219
+ using LongIndex = int64_t;
220
+
221
+ //
222
+ // Data members
223
+ //
224
+ /// stride of pitch-linear layout (units of Element)
225
+ LongIndex stride_ = 0;
226
+ /// amount (in byte) to increment pointer to move to next access along
227
+ /// strided dimension
228
+ LongIndex inc_strided_ = 0;
229
+ /// amount (in byte) to increment pointer from last access to first access
230
+ /// of next tile
231
+ LongIndex inc_next_ = 0;
232
+ /// amount (in byte) to increment pointer from first access of current tile
233
+ /// to first access of next tile
234
+ LongIndex inc_advance_ = 0;
235
+
236
+ //
237
+ // Methods
238
+ //
239
+
240
+ CUTLASS_HOST_DEVICE
241
+ Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) {
242
+ CUTLASS_ASSERT(desc.element_size_bits > 0);
243
+ CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1);
244
+
245
+ stride_ = stride;
246
+
247
+ inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) *
248
+ desc.element_size_bits / 8;
249
+
250
+ if (desc.advance_rank) {
251
+ // advance along strided dimension
252
+ inc_advance_ =
253
+ desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8;
254
+ } else {
255
+ // advance along contiguous dimension
256
+ inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8;
257
+ }
258
+
259
+ inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) *
260
+ desc.threadmap_delta.strided() * LongIndex(stride_) *
261
+ desc.element_size_bits / 8;
262
+
263
+ return Status::kSuccess;
264
+ }
265
+
266
+ CUTLASS_HOST_DEVICE
267
+ Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) {
268
+ return initialize(LongIndex(stride), desc);
269
+ }
270
+
271
+ PredicatedTileAccessIteratorParams() = default;
272
+
273
+ CUTLASS_HOST_DEVICE
274
+ PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) {
275
+ initialize(stride, desc);
276
+ }
277
+
278
+ CUTLASS_HOST_DEVICE
279
+ PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) {
280
+ initialize(stride, desc);
281
+ }
282
+ };
283
+
284
+ ////////////////////////////////////////////////////////////////////////////////
285
+
286
+ } // namespace threadblock
287
+ } // namespace transform
288
+ } // namespace cutlass
289
+
290
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates calculating the address and predicates to the load of tiles
33
+ from pitch-linear rank=2 tensors.
34
+
35
+ This iterator uses masks to guard out-of-bounds accesses and visits the last
36
+ "residue" tile first, with the objective of minimizing predicate mask updates
37
+ during steady-state operation.
38
+
39
+ A precomputed "Params" object minimizes the amount of state that must be
40
+ stored in registers, and integer addition is used to advance the pointer
41
+ through memory.
42
+
43
+
44
+ */
45
+
46
+ #pragma once
47
+
48
+ #include "cutlass/blas3.h"
49
+ #include "cutlass/layout/matrix.h"
50
+ #include "cutlass/layout/pitch_linear.h"
51
+ #include "cutlass/matrix_shape.h"
52
+ #include "cutlass/predicate_vector.h"
53
+ #include "cutlass/tensor_ref.h"
54
+ #include "cutlass/tensor_view.h"
55
+
56
+ ////////////////////////////////////////////////////////////////////////////////
57
+
58
+ ////////////////////////////////////////////////////////////////////////////////
59
+
60
+ namespace cutlass {
61
+ namespace transform {
62
+ namespace threadblock {
63
+
64
+ ////////////////////////////////////////////////////////////////////////////////
65
+
66
+ /// PredicatedTileAccessIteratorTriangularMatrix
67
+ ///
68
+ template <typename Shape, typename Element, typename Layout,
69
+ int AdvanceRank, typename ThreadMap,
70
+ SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
71
+ typename AccessType>
72
+ class PredicatedTileAccessIteratorTriangularMatrix;
73
+
74
+ ////////////////////////////////////////////////////////////////////////////////
75
+
76
+ /// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data.
77
+ ///
78
+ template <typename Shape_, typename Element_, int AdvanceRank,
79
+ typename ThreadMap_, SideMode kSideMode, FillMode kFillMode, DiagType kDiagType, typename AccessType_>
80
+ class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::PitchLinear,
81
+ AdvanceRank, ThreadMap_, kSideMode, kFillMode, kDiagType, AccessType_> {
82
+ public:
83
+ static_assert(
84
+ AdvanceRank == 0 || AdvanceRank == 1,
85
+ "Specialization for pitch-linear iterator may along advance along the "
86
+ "contiguous(rank=0) or strided(rank=1) dimension.");
87
+
88
+ using Shape = Shape_;
89
+ using Element = Element_;
90
+ using Layout = layout::PitchLinear;
91
+ static int const kAdvanceRank = AdvanceRank;
92
+ using ThreadMap = ThreadMap_;
93
+ using AccessType = AccessType_;
94
+
95
+ using Index = typename Layout::Index;
96
+ using LongIndex = typename Layout::LongIndex;
97
+ using StrideIndex = typename Layout::Stride::Index;
98
+
99
+ using TensorRef = TensorRef<Element, Layout>;
100
+ using TensorView = TensorView<Element, Layout>;
101
+ using TensorCoord = typename Layout::TensorCoord;
102
+
103
+ using Pointer = Element *;
104
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
105
+
106
+ static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
107
+
108
+ using CompareOp = typename TrMatrixCompareOp<kFillMode, kDiagType>::Type;
109
+
110
+ static_assert( kFillMode == FillMode::kFull ||
111
+ ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1),
112
+ "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1");
113
+
114
+ static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
115
+ "Vectors implied by the thread map must be divisible by the access type.");
116
+
117
+ static int const kPredicatesPerByte = 4;
118
+ static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
119
+
120
+ static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
121
+
122
+ /// Number of 32b words containing predicates
123
+ static int const kPredicateByteCount =
124
+ (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
125
+ static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
126
+
127
+ static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
128
+
129
+ static_assert(kPredicateWordCount <= 4, "Too many predicates.");
130
+
131
+ /// Predicate vector stores mask to guard accesses
132
+ using Mask = Array<uint32_t, kPredicateWordCount>;
133
+
134
+ /// Parameters object is precomputed state and is host-constructible
135
+ class Params {
136
+ public:
137
+ friend PredicatedTileAccessIteratorTriangularMatrix;
138
+
139
+ private:
140
+ /// stride of pitch-linear layout (units of Element)
141
+ StrideIndex stride_;
142
+ /// (true) pitch-linear layout is mapped to row-major matrix
143
+ /// (false) pitch-linear layout is mapped to column-major matrix
144
+ bool is_row_major_;
145
+ /// for vectorized access across the diagonal boundary guard condition is
146
+ /// checked for the element on the boundary
147
+ int access_diagonal_boundary_;
148
+ /// amount (in byte) to increment pointer to move to next access along
149
+ /// strided dimension
150
+ LongIndex inc_strided_;
151
+ /// amount (in byte) to increment pointer from last access to first access
152
+ /// of next tile
153
+ LongIndex inc_next_;
154
+ /// amount (in byte) to increment pointer from first access of current tile
155
+ /// to first access of next tile
156
+ LongIndex inc_advance_;
157
+
158
+ public:
159
+
160
+ // Default ctor
161
+ CUTLASS_HOST_DEVICE
162
+ Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { }
163
+
164
+ /// Construct the Params object given a pitch-linear tensor's layout
165
+ CUTLASS_HOST_DEVICE
166
+ Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) :
167
+ stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) {
168
+
169
+ inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
170
+ sizeof_bits<Element>::value / 8;
171
+
172
+ if (kAdvanceRank) {
173
+ // advance along strided dimension
174
+ inc_advance_ =
175
+ Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
176
+ } else {
177
+ // advance along contiguous dimension
178
+ inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
179
+ }
180
+
181
+ inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
182
+ ThreadMap::Delta::kStrided * LongIndex(stride_) *
183
+ sizeof_bits<Element>::value / 8;
184
+
185
+ };
186
+
187
+
188
+ };
189
+
190
+ private:
191
+ /// Internal pointer type permits fast address arithmetic
192
+ using BytePointer = char *;
193
+
194
+ private:
195
+ //
196
+ // Data members
197
+ //
198
+
199
+ /// Parameters object with precomputed internal state
200
+ Params const &params_;
201
+
202
+ /// Internal pointer to first access of tile
203
+ BytePointer pointer_;
204
+
205
+ /// Guard predicates
206
+ uint32_t predicates_[kPredicateWordCount];
207
+
208
+ /// Track global memory addresses on the diagonal
209
+ /// To ignore imag part for diagonal elements of hermitian matrices
210
+ uint32_t predicates_onDiag_[kPredicateWordCount];
211
+
212
+ /// Size of tensor
213
+ TensorCoord extent_;
214
+
215
+ /// Initial offset for each thread
216
+ TensorCoord thread_offset_;
217
+
218
+ /// Iteration along vectors implied by the thread map
219
+ int iteration_vector_;
220
+
221
+ /// Iteration in the contiguous dimension
222
+ int iteration_contiguous_;
223
+
224
+ /// Iteration in the strided dimension
225
+ int iteration_strided_;
226
+
227
+ private:
228
+ /// Computes predicates based on internally tracked per-thread offset.
229
+ CUTLASS_DEVICE
230
+ void compute_predicates_(
231
+ /// Extent of the matrix window
232
+ TensorCoord extent) {
233
+
234
+ CUTLASS_PRAGMA_UNROLL
235
+ for (int i = 0; i < kPredicateWordCount; ++i) {
236
+ predicates_[i] = 0u;
237
+ predicates_onDiag_[i] = 0u;
238
+ }
239
+
240
+ CompareOp compare_op;
241
+
242
+ CUTLASS_PRAGMA_UNROLL
243
+ for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
244
+
245
+ int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
246
+
247
+ int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
248
+
249
+ int c = access_residual / kAccessesPerVector;
250
+ int v = access_residual % kAccessesPerVector;
251
+
252
+ TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
253
+ s * ThreadMap::Delta::kStrided);
254
+
255
+ TensorCoord coord = thread_offset_ + iteration_coord;
256
+
257
+ bool guard;
258
+ bool onDiag = false;
259
+
260
+ guard = ((coord.strided() < extent.strided()) &&
261
+ (coord.contiguous() < extent.contiguous()));
262
+
263
+
264
+ // guard access on the wrong side of the triagular matrix diagonal
265
+ if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) {
266
+ coord += TensorCoord{params_.access_diagonal_boundary_, 0};
267
+
268
+ bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_;
269
+ bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_;
270
+
271
+ guard = guard && triagular_guard_row_major && triagular_guard_col_major;
272
+
273
+ if (kDiagType == DiagType::kUnit) {
274
+ onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false;
275
+ }
276
+ }
277
+
278
+ int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
279
+ int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord;
280
+ int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord;
281
+ int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte;
282
+ int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte;
283
+
284
+ predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag));
285
+
286
+ int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
287
+
288
+ int word_idx = pred_idx / kPredicatesPerWord;
289
+ int residual = pred_idx % kPredicatesPerWord;
290
+ int byte_idx = residual / kPredicatesPerByte;
291
+ int bit_idx = residual % kPredicatesPerByte;
292
+
293
+ predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
294
+
295
+ }
296
+
297
+ }
298
+
299
+ public:
300
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
301
+ /// and thread ID
302
+ CUTLASS_HOST_DEVICE
303
+ PredicatedTileAccessIteratorTriangularMatrix(
304
+ /// Precomputed parameters object
305
+ Params const &params,
306
+ /// Pointer to start of tensor
307
+ Pointer pointer,
308
+ /// Extent of tensor
309
+ TensorCoord extent,
310
+ /// ID of each participating thread
311
+ int thread_id,
312
+ /// Initial offset of threadblock
313
+ TensorCoord const &threadblock_offset)
314
+ : params_(params),
315
+ pointer_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer))),
316
+ extent_(extent) {
317
+
318
+
319
+ // Per-thread offset in logical coordinates of tensor
320
+ thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
321
+
322
+ // update internal pointers
323
+ Layout layout(params_.stride_);
324
+ add_pointer_offset(layout(thread_offset_));
325
+
326
+ compute_predicates_(extent_);
327
+
328
+ set_iteration_index(0);
329
+ }
330
+
331
+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
332
+ CUTLASS_HOST_DEVICE
333
+ PredicatedTileAccessIteratorTriangularMatrix(
334
+ /// Precomputed parameters object
335
+ Params const &params,
336
+ /// Pointer to start of tensor
337
+ Pointer pointer,
338
+ /// Extent of tensor
339
+ TensorCoord extent,
340
+ ///< ID of each participating thread
341
+ int thread_id)
342
+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
343
+ make_Coord(0, 0)) {}
344
+
345
+ /// Overrides the internal iteration index
346
+ CUTLASS_HOST_DEVICE
347
+ void set_iteration_index(int index) {
348
+
349
+ iteration_vector_ = index % kAccessesPerVector;
350
+ int residual_access = index / kAccessesPerVector;
351
+
352
+ iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
353
+ iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
354
+
355
+ }
356
+
357
+ /// Adds a pointer offset in units of Element
358
+ CUTLASS_HOST_DEVICE
359
+ void add_pointer_offset(LongIndex pointer_offset) {
360
+ pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
361
+ }
362
+
363
+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles
364
+ CUTLASS_DEVICE
365
+ void add_tile_offset(TensorCoord const &tile_offset) {
366
+
367
+ if (kAdvanceRank) {
368
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
369
+ pointer_ += Shape::kContiguous * tile_offset.contiguous();
370
+ thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()};
371
+ } else {
372
+ pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
373
+ pointer_ += Shape::kStrided * tile_offset.strided();
374
+ thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0};
375
+ }
376
+
377
+ compute_predicates_(extent_);
378
+ }
379
+
380
+ /// Returns a pointer
381
+ CUTLASS_HOST_DEVICE
382
+ AccessType *get() const {
383
+ return reinterpret_cast<AccessType *>(
384
+ pointer_ +
385
+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
386
+ }
387
+
388
+ /// Increment and return an instance to self.
389
+ CUTLASS_HOST_DEVICE
390
+ PredicatedTileAccessIteratorTriangularMatrix &operator++() {
391
+
392
+ ++iteration_vector_;
393
+ if (iteration_vector_ < kAccessesPerVector) {
394
+ return *this;
395
+ }
396
+
397
+ iteration_vector_ = 0;
398
+ ++iteration_contiguous_;
399
+
400
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
401
+ return *this;
402
+ }
403
+
404
+ // Enter here only if (iteration_contiguous_ ==
405
+ // ThreadMap::Iteration::kContiguous)
406
+ iteration_contiguous_ = 0;
407
+ ++iteration_strided_;
408
+
409
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
410
+ pointer_ += params_.inc_strided_;
411
+ return *this;
412
+ }
413
+
414
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
415
+ // which means we enter the next tile.
416
+ iteration_strided_ = 0;
417
+
418
+ // advance to next tile
419
+ pointer_ += params_.inc_next_;
420
+
421
+ // now return to start tile - if the iterator is subsequently advanced, this
422
+ // subtraction as well as the subsequent integer addition are both elided by
423
+ // the compiler.
424
+ pointer_ -= params_.inc_advance_;
425
+
426
+ return *this;
427
+ }
428
+
429
+ /// Increment and return an instance to self.
430
+ CUTLASS_HOST_DEVICE
431
+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
432
+ PredicatedTileAccessIteratorTriangularMatrix self(*this);
433
+ operator++();
434
+ return self;
435
+ }
436
+
437
+ /// Clears the predicate set efficiently
438
+ CUTLASS_HOST_DEVICE
439
+ void clear_mask(bool enable = true) {
440
+ CUTLASS_PRAGMA_UNROLL
441
+ for (int i = 0; i < kPredicateWordCount; ++i) {
442
+ predicates_[i] = enable ? 0u : predicates_[i];
443
+ }
444
+
445
+ }
446
+
447
+ /// Clears the predicate set efficiently
448
+ CUTLASS_HOST_DEVICE
449
+ void enable_mask() {
450
+ CUTLASS_PRAGMA_UNROLL
451
+ for (int i = 0; i < kPredicateWordCount; ++i) {
452
+ predicates_[i] = 0xffffffff;
453
+ }
454
+ }
455
+
456
+ /// Sets the predicate mask, overriding value stored in predicate iterator
457
+ CUTLASS_HOST_DEVICE
458
+ void set_mask(Mask const &mask) {
459
+ CUTLASS_PRAGMA_UNROLL
460
+ for (int i = 0; i < kPredicateWordCount; ++i) {
461
+ predicates_[i] = mask[i];
462
+ }
463
+
464
+ }
465
+
466
+ /// Gets the mask
467
+ CUTLASS_HOST_DEVICE
468
+ void get_mask(Mask &mask) {
469
+ CUTLASS_PRAGMA_UNROLL
470
+ for (int i = 0; i < kPredicateWordCount; ++i) {
471
+ mask[i] = predicates_[i];
472
+ }
473
+ }
474
+
475
+ /// Return if the address in on the diagonal
476
+ CUTLASS_HOST_DEVICE
477
+ bool getOnDiag() {
478
+ int pred_idx =
479
+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
480
+
481
+ int word_idx = pred_idx / kPredicatesPerWord;
482
+ int residual = pred_idx % kPredicatesPerWord;
483
+ int byte_idx = residual / kPredicatesPerByte;
484
+ int bit_idx = residual % kPredicatesPerByte;
485
+
486
+ bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
487
+ return pred;
488
+ }
489
+
490
+ /// Returns whether access is valid or not
491
+ CUTLASS_HOST_DEVICE
492
+ bool valid() {
493
+
494
+
495
+ int pred_idx =
496
+ iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
497
+
498
+ int word_idx = pred_idx / kPredicatesPerWord;
499
+ int residual = pred_idx % kPredicatesPerWord;
500
+ int byte_idx = residual / kPredicatesPerByte;
501
+ int bit_idx = residual % kPredicatesPerByte;
502
+
503
+ bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
504
+ return pred;
505
+
506
+
507
+ //return true;
508
+ }
509
+ };
510
+
511
+ ////////////////////////////////////////////////////////////////////////////////
512
+
513
+ /// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data.
514
+ ///
515
+ /// Satisfies: ForwardTileIteratorConcept |
516
+ /// ReadableContiguousTileIteratorConcept |
517
+ /// WriteableContiguousTileIteratorConcept |
518
+ /// MaskedTileIteratorConcept
519
+ ///
520
+ template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
521
+ SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
522
+ typename AccessType_>
523
+ class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::ColumnMajor,
524
+ AdvanceRank, ThreadMap_, kSideMode, kFillMode, kDiagType,
525
+ AccessType_> {
526
+ public:
527
+ static_assert(
528
+ AdvanceRank == 0 || AdvanceRank == 1,
529
+ "Specialization for pitch-linear iterator may along advance along the "
530
+ "contiguous(rank=0) or strided(rank=1) dimension.");
531
+
532
+ using Shape = Shape_;
533
+ using Element = Element_;
534
+ using Layout = layout::ColumnMajor;
535
+ static int const kAdvanceRank = AdvanceRank;
536
+ using ThreadMap = ThreadMap_;
537
+ using AccessType = AccessType_;
538
+
539
+ using Index = typename Layout::Index;
540
+ using LongIndex = typename Layout::LongIndex;
541
+
542
+ using TensorRef = TensorRef<Element, Layout>;
543
+ using TensorView = TensorView<Element, Layout>;
544
+ using TensorCoord = typename Layout::TensorCoord;
545
+
546
+ using Pointer = Element *;
547
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
548
+
549
+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix<
550
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
551
+ layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
552
+ kSideMode, kFillMode, kDiagType, AccessType>;
553
+
554
+ /// Predicate vector stores mask to guard accesses
555
+ using Mask = typename UnderlyingIterator::Mask;
556
+
557
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
558
+
559
+ static int const kAccessDiagonalBoundary =
560
+ (kFillMode == FillMode::kLower) ? (AccessType::kElements - 1) : 0;
561
+
562
+ /// Parameters object is precomputed state and is host-constructible
563
+ class Params {
564
+ private:
565
+ friend PredicatedTileAccessIteratorTriangularMatrix;
566
+
567
+ /// Parameters object
568
+ typename UnderlyingIterator::Params params_;
569
+
570
+ public:
571
+
572
+ /// Default ctor
573
+ CUTLASS_HOST_DEVICE
574
+ Params() { }
575
+
576
+ /// Construct the Params object given a pitch-linear tensor's layout
577
+ CUTLASS_HOST_DEVICE
578
+ Params(Layout const &layout)
579
+ : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){};
580
+ };
581
+
582
+ private:
583
+ //
584
+ // Data members
585
+ //
586
+
587
+ /// Underlying pitch-linear tile iterator
588
+ UnderlyingIterator iterator_;
589
+
590
+ public:
591
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
592
+ /// and thread ID
593
+ CUTLASS_HOST_DEVICE
594
+ PredicatedTileAccessIteratorTriangularMatrix(
595
+ ///< Precomputed parameters object
596
+ Params const &params,
597
+ ///< Pointer to start of tensor
598
+ Pointer pointer,
599
+ ///< Extent of tensor
600
+ TensorCoord extent,
601
+ ///< ID of each participating thread
602
+ int thread_id,
603
+ ///< Initial offset of threadblock
604
+ TensorCoord const &threadblock_offset)
605
+ : iterator_(params.params_, pointer,
606
+ layout::PitchLinearCoord(extent.row(), extent.column()),
607
+ thread_id,
608
+ layout::PitchLinearCoord(threadblock_offset.row(),
609
+ threadblock_offset.column())) {}
610
+
611
+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
612
+ CUTLASS_HOST_DEVICE
613
+ PredicatedTileAccessIteratorTriangularMatrix(
614
+ Params const &params, ///< Precomputed parameters object
615
+ Pointer pointer, ///< Pointer to start of tensor
616
+ TensorCoord extent, ///< Extent of tensor
617
+ int thread_id ///< ID of each participating thread
618
+ )
619
+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
620
+ make_Coord(0, 0)) {}
621
+
622
+ /// Overrides the internal iteration index
623
+ CUTLASS_HOST_DEVICE
624
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
625
+
626
+ /// Adds a pointer offset in units of Element
627
+ CUTLASS_HOST_DEVICE
628
+ void add_pointer_offset(LongIndex pointer_offset) {
629
+ iterator_.add_pointer_offset(pointer_offset);
630
+ }
631
+
632
+ /// Advances an iterator along logical dimensions of matrix in units of whole
633
+ /// tiles
634
+ CUTLASS_HOST_DEVICE
635
+ void add_tile_offset(TensorCoord const &tile_offset) {
636
+ iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
637
+ }
638
+
639
+ /// Returns a pointer
640
+ CUTLASS_HOST_DEVICE
641
+ AccessType *get() const {
642
+ return reinterpret_cast<AccessType *>(iterator_.get());
643
+ }
644
+
645
+ /// Advances to the next tile in memory.
646
+ ///
647
+ /// The first time this method is called, predicates are updated, and the
648
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
649
+ /// Subsequent calls are lightweight and must only update the internal
650
+ /// pointer.
651
+ CUTLASS_HOST_DEVICE
652
+ PredicatedTileAccessIteratorTriangularMatrix &operator++() {
653
+ ++iterator_;
654
+ return *this;
655
+ }
656
+
657
+ /// Advances to the next tile in memory.
658
+ ///
659
+ /// The first time this method is called, predicates are updated, and the
660
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
661
+ /// Subsequent calls are lightweight and must only update the internal
662
+ /// pointer.
663
+ CUTLASS_HOST_DEVICE
664
+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
665
+ PredicatedTileAccessIteratorTriangularMatrix self(*this);
666
+ operator++();
667
+ return self;
668
+ }
669
+
670
+ /// Clears the predicate set efficiently
671
+ CUTLASS_HOST_DEVICE
672
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
673
+
674
+ /// Clears the predicate set efficiently
675
+ CUTLASS_HOST_DEVICE
676
+ void enable_mask() { iterator_.enable_mask(); }
677
+
678
+ /// Sets the predicate mask, overriding value stored in predicate iterator
679
+ CUTLASS_HOST_DEVICE
680
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
681
+
682
+ /// Gets the mask
683
+ CUTLASS_HOST_DEVICE
684
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
685
+
686
+ /// Return if the address in on the diagonal
687
+ CUTLASS_HOST_DEVICE
688
+ bool getOnDiag() {
689
+ return iterator_.getOnDiag();
690
+ }
691
+
692
+ /// Returns whether access is valid or not
693
+ CUTLASS_HOST_DEVICE
694
+ bool valid() {
695
+ return iterator_.valid();
696
+ }
697
+ };
698
+
699
+ ////////////////////////////////////////////////////////////////////////////////
700
+
701
+ /// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data.
702
+ ///
703
+ /// Satisfies: ForwardTileIteratorConcept |
704
+ /// ReadableContiguousTileIteratorConcept |
705
+ /// WriteableContiguousTileIteratorConcept |
706
+ /// MaskedTileIteratorConcept
707
+ ///
708
+ template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
709
+ SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
710
+ typename AccessType_>
711
+ class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_,
712
+ kSideMode, kFillMode, kDiagType, AccessType_> {
713
+ public:
714
+ static_assert(
715
+ AdvanceRank == 0 || AdvanceRank == 1,
716
+ "Specialization for pitch-linear iterator may along advance along the "
717
+ "contiguous(rank=0) or strided(rank=1) dimension.");
718
+
719
+ using Shape = Shape_;
720
+ using Element = Element_;
721
+ using Layout = layout::RowMajor;
722
+ static int const kAdvanceRank = AdvanceRank;
723
+ using ThreadMap = ThreadMap_;
724
+ using AccessType = AccessType_;
725
+
726
+ using Index = typename Layout::Index;
727
+ using LongIndex = typename Layout::LongIndex;
728
+
729
+ using TensorRef = TensorRef<Element, Layout>;
730
+ using TensorView = TensorView<Element, Layout>;
731
+ using TensorCoord = typename Layout::TensorCoord;
732
+
733
+ using Pointer = Element *;
734
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
735
+
736
+ using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix<
737
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
738
+ layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
739
+ kSideMode, kFillMode, kDiagType, AccessType>;
740
+
741
+ static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
742
+
743
+ static int const kAccessDiagonalBoundary =
744
+ (kFillMode == FillMode::kUpper) ? (AccessType::kElements - 1) : 0;
745
+
746
+ /// Predicate vector stores mask to guard accesses
747
+ using Mask = typename UnderlyingIterator::Mask;
748
+
749
+ /// Parameters object is precomputed state and is host-constructible
750
+ class Params {
751
+ private:
752
+ friend PredicatedTileAccessIteratorTriangularMatrix;
753
+
754
+ /// Parameters object
755
+ typename UnderlyingIterator::Params params_;
756
+
757
+ public:
758
+
759
+ /// Default ctor
760
+ CUTLASS_HOST_DEVICE
761
+ Params() { }
762
+
763
+ /// Construct the Params object given a pitch-linear tensor's layout
764
+ CUTLASS_HOST_DEVICE
765
+ Params(Layout const &layout)
766
+ : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){};
767
+ };
768
+
769
+ private:
770
+ //
771
+ // Data members
772
+ //
773
+
774
+ /// Underlying pitch-linear tile iterator
775
+ UnderlyingIterator iterator_;
776
+
777
+ public:
778
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
779
+ /// and thread ID
780
+ CUTLASS_HOST_DEVICE
781
+ PredicatedTileAccessIteratorTriangularMatrix(
782
+ ///< Precomputed parameters object
783
+ Params const &params,
784
+ ///< Pointer to start of tensor
785
+ Pointer pointer,
786
+ ///< Extent of tensor
787
+ TensorCoord extent,
788
+ ///< ID of each participating thread
789
+ int thread_id,
790
+ ///< Initial offset of threadblock
791
+ TensorCoord const &threadblock_offset)
792
+ : iterator_(params.params_, pointer,
793
+ layout::PitchLinearCoord(extent.column(), extent.row()),
794
+ thread_id,
795
+ layout::PitchLinearCoord(threadblock_offset.column(),
796
+ threadblock_offset.row())) {}
797
+
798
+ /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
799
+ CUTLASS_HOST_DEVICE
800
+ PredicatedTileAccessIteratorTriangularMatrix(
801
+ Params const &params, ///< Precomputed parameters object
802
+ Pointer pointer, ///< Pointer to start of tensor
803
+ TensorCoord extent, ///< Extent of tensor
804
+ int thread_id ///< ID of each participating thread
805
+ )
806
+ : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
807
+ make_Coord(0, 0)) {}
808
+
809
+ /// Overrides the internal iteration index
810
+ CUTLASS_HOST_DEVICE
811
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
812
+
813
+ /// Adds a pointer offset in units of Element
814
+ CUTLASS_HOST_DEVICE
815
+ void add_pointer_offset(LongIndex pointer_offset) {
816
+ iterator_.add_pointer_offset(pointer_offset);
817
+ }
818
+
819
+ /// Advances an iterator along logical dimensions of matrix in units of whole
820
+ /// tiles
821
+ CUTLASS_HOST_DEVICE
822
+ void add_tile_offset(TensorCoord const &tile_offset) {
823
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
824
+ }
825
+
826
+ /// Returns a pointer
827
+ CUTLASS_HOST_DEVICE
828
+ AccessType *get() const {
829
+ return reinterpret_cast<AccessType *>(iterator_.get());
830
+ }
831
+
832
+ /// Advances to the next tile in memory.
833
+ ///
834
+ /// The first time this method is called, predicates are updated, and the
835
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
836
+ /// Subsequent calls are lightweight and must only update the internal
837
+ /// pointer.
838
+ CUTLASS_HOST_DEVICE
839
+ PredicatedTileAccessIteratorTriangularMatrix &operator++() {
840
+ ++iterator_;
841
+ return *this;
842
+ }
843
+
844
+ /// Advances to the next tile in memory.
845
+ ///
846
+ /// The first time this method is called, predicates are updated, and the
847
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
848
+ /// Subsequent calls are lightweight and must only update the internal
849
+ /// pointer.
850
+ CUTLASS_HOST_DEVICE
851
+ PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
852
+ PredicatedTileAccessIteratorTriangularMatrix self(*this);
853
+ operator++();
854
+ return self;
855
+ }
856
+
857
+ /// Clears the predicate set efficiently
858
+ CUTLASS_HOST_DEVICE
859
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
860
+
861
+ /// Clears the predicate set efficiently
862
+ CUTLASS_HOST_DEVICE
863
+ void enable_mask() { iterator_.enable_mask(); }
864
+
865
+ /// Sets the predicate mask, overriding value stored in predicate iterator
866
+ CUTLASS_HOST_DEVICE
867
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
868
+
869
+ /// Gets the mask
870
+ CUTLASS_HOST_DEVICE
871
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
872
+
873
+ /// Return if the address in on the diagonal
874
+ CUTLASS_HOST_DEVICE
875
+ bool getOnDiag() {
876
+ return iterator_.getOnDiag();
877
+ }
878
+
879
+ /// Returns whether access is valid or not
880
+ CUTLASS_HOST_DEVICE
881
+ bool valid() {
882
+ return iterator_.valid();
883
+ }
884
+ };
885
+
886
+ ////////////////////////////////////////////////////////////////////////////////
887
+
888
+ } // namespace threadblock
889
+ } // namespace transform
890
+ } // namespace cutlass
891
+
892
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h ADDED
@@ -0,0 +1,1887 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
33
+
34
+ This iterator uses masks to guard out-of-bounds accesses. The first tile this
35
+ iterator visits maybe partial, then the remaining tiles are complete. So, we
36
+ only need to compute the predicates twice, once before the first tile and
37
+ once for the remaining full tiles which can share the same predicates.
38
+
39
+ A precomputed "Params" object minimizes the amount of state that must be stored in registers,
40
+ and integer addition is used to advance the pointer through memory.
41
+ */
42
+
43
+ #pragma once
44
+
45
+ #include "cutlass/arch/memory.h"
46
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace transform {
52
+ namespace threadblock {
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// PredicatedTileIterator
57
+ ///
58
+ /// Satisfies: ForwardTileIteratorConcept |
59
+ /// ReadableContiguousTileIteratorConcept |
60
+ /// WriteableContiguousTileIteratorConcept |
61
+ /// MaskedTileIteratorConcept
62
+ ///
63
+ /// Regular tile iterator using a precomputed control structure to minimize register liveness
64
+ /// and integer arithmetic.
65
+ ///
66
+ /// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
67
+ ///
68
+ /// Base pointer and tensor extents may be specified at the time the iterator is constructed.
69
+ /// Subsequently, they are assumed to be immutable.
70
+ ///
71
+ /// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
72
+ /// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
73
+ ///
74
+ /// Visitation order is intended to first visit a "residual" tile that may be partially full in
75
+ /// both the advance dimension and the steady-state dimension. This is assumed to be the last
76
+ /// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
77
+ /// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
78
+ /// accesses may be performed without updating internal predicates and are efficient in terms of
79
+ /// live register state and pointer arithmetic instructions.
80
+ ///
81
+ /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
82
+ /// outside any looping structure to minimize integer arithmetic.
83
+ ///
84
+ /// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
85
+ /// the iterator.
86
+ ///
87
+ ///
88
+ /// Example:
89
+ ///
90
+ /// An efficient pipeline structure may be constructed as follows:
91
+ ///
92
+ // template <typename Iterator>
93
+ // __global__ void kernel(
94
+ // typename Iterator::Params params,
95
+ // typename Iterator::Element *ptr,
96
+ // TensorCoord extent) {
97
+ //
98
+ // typename Iterator::Fragment fragment;
99
+ //
100
+ // TensorCoord threadblock_offset(0, 0);
101
+ //
102
+ // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
103
+ //
104
+ //
105
+ // fragment = *iter; // load "residue" tile first
106
+ // ++iter; // advance to first "steady state" tile and update internal masks
107
+ //
108
+ //
109
+ // #pragma unroll
110
+ // for (int i = Remaining - 1; i >= 0; --i) {
111
+ //
112
+ // f(fragment);
113
+ //
114
+ // if (!i) {
115
+ // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
116
+ // }
117
+ //
118
+ // fragment = *iter; // load tile during "steady state" phase
119
+ // ++iter; // advance to next tile - lightweight due to steady-state masks
120
+ // }
121
+ // }
122
+ //
123
+ // void host(TensorView<Element, 2, layout::PitchLinear> view) {
124
+ //
125
+ // using Iterator = transform::threadblock::PredicatedTileIterator;
126
+ //
127
+ // typename Iterator::Params params(view.layout());
128
+ //
129
+ // kernel<Iterator>(params, view.data());
130
+ // }
131
+ ///
132
+ ///
133
+ template <
134
+ typename Shape,
135
+ typename Element,
136
+ typename Layout,
137
+ int AdvanceRank,
138
+ typename ThreadMap,
139
+ int AccessSize = ThreadMap::kElementsPerAccess,
140
+ bool Gather = false,
141
+ typename PermuteLayout = layout::NoPermute
142
+ >
143
+ class PredicatedTileIterator;
144
+
145
+ ////////////////////////////////////////////////////////////////////////////////
146
+
147
+ /// Specialization of PredicatedTileIterator for pitch-linear data.
148
+ ///
149
+ /// Satisfies: ForwardTileIteratorConcept |
150
+ /// ReadableContiguousTileIteratorConcept |
151
+ /// WriteableContiguousTileIteratorConcept |
152
+ /// MaskedTileIteratorConcept
153
+ ///
154
+ template <typename Shape_, typename Element_, int AdvanceRank,
155
+ typename ThreadMap_, int AccessSize, bool Gather, typename PermuteLayout>
156
+ class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
157
+ ThreadMap_, AccessSize, Gather, PermuteLayout> {
158
+ public:
159
+ static_assert(
160
+ AdvanceRank == 0 || AdvanceRank == 1,
161
+ "Specialization for pitch-linear iterator may advance along the "
162
+ "contiguous(rank=0) or strided(rank=1) dimension.");
163
+
164
+ using Shape = Shape_;
165
+ using Element = Element_;
166
+ using Layout = layout::PitchLinear;
167
+ static int const kAdvanceRank = AdvanceRank;
168
+ using ThreadMap = ThreadMap_;
169
+
170
+ using Index = typename Layout::Index;
171
+ using LongIndex = typename Layout::LongIndex;
172
+
173
+ using TensorRef = TensorRef<Element, Layout>;
174
+ using TensorView = TensorView<Element, Layout>;
175
+ using TensorCoord = typename Layout::TensorCoord;
176
+
177
+ using Pointer = Element *;
178
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
179
+
180
+ /// Type used for internal memory accesses
181
+ using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
182
+
183
+ /// Underlying iterator to compute the addresses
184
+ using TileAccessIterator =
185
+ PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
186
+ ThreadMap, AccessType, Gather, PermuteLayout>;
187
+
188
+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
189
+
190
+ /// Fragment object to be loaded or stored
191
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
192
+ ThreadMap::kElementsPerAccess>;
193
+
194
+ /// Predicate vector stores mask to guard accesses
195
+ using Mask = typename TileAccessIterator::Mask;
196
+
197
+ /// Parameters object is precomputed state and is host-constructible
198
+ class Params {
199
+ public:
200
+ using Base = typename TileAccessIterator::Params::Base;
201
+
202
+ friend PredicatedTileIterator;
203
+
204
+ private:
205
+ /// Parameters object
206
+ typename TileAccessIterator::Params params_;
207
+
208
+ public:
209
+ /// Construct the Params object given a pitch-linear tensor's layout
210
+ CUTLASS_HOST_DEVICE
211
+ Params(Layout const &layout) : params_(layout) {}
212
+
213
+ /// Default constructor
214
+ Params() = default;
215
+
216
+ CUTLASS_HOST_DEVICE
217
+ Params(Base const &base)
218
+ : params_(base) {}
219
+ };
220
+
221
+ private:
222
+ /// Internal pointer type permits fast address arithmetic
223
+ using BytePointer = char *;
224
+
225
+ private:
226
+ //
227
+ // Data members
228
+ //
229
+
230
+ /// Data member to the tile access iterator
231
+ TileAccessIterator address_iterator_;
232
+
233
+ public:
234
+
235
+ /// Default constructor
236
+ PredicatedTileIterator() = default;
237
+
238
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
239
+ /// and thread ID
240
+ CUTLASS_HOST_DEVICE
241
+ PredicatedTileIterator(
242
+ /// Precomputed parameters object
243
+ Params const &params,
244
+ /// Pointer to start of tensor
245
+ Pointer pointer,
246
+ /// Extent of tensor
247
+ TensorCoord extent,
248
+ /// ID of each participating thread
249
+ int thread_id,
250
+ /// Initial offset of threadblock
251
+ TensorCoord const &threadblock_offset,
252
+ /// Gather indices
253
+ int const *indices = nullptr)
254
+ : address_iterator_(params.params_, pointer, extent, thread_id,
255
+ threadblock_offset, indices) {}
256
+
257
+ /// Construct a PredicatedTileIterator with zero threadblock offset
258
+ CUTLASS_HOST_DEVICE
259
+ PredicatedTileIterator(
260
+ Params const &params, ///< Precomputed parameters object
261
+ Pointer pointer, ///< Pointer to start of tensor
262
+ TensorCoord extent, ///< Extent of tensor
263
+ int thread_id ///< ID of each participating thread
264
+ )
265
+ : PredicatedTileIterator(params, pointer, extent, thread_id,
266
+ make_Coord(0, 0)) {}
267
+
268
+ /// Adds a pointer offset in units of Element
269
+ CUTLASS_HOST_DEVICE
270
+ void add_pointer_offset(LongIndex pointer_offset) {
271
+ address_iterator_.add_pointer_offset(pointer_offset);
272
+ }
273
+
274
+ /// Advances to the next tile in memory.
275
+ ///
276
+ /// The first time this method is called, predicates are updated, and the
277
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
278
+ /// Subsequent calls are lightweight and must only update the internal
279
+ /// pointer.
280
+ CUTLASS_HOST_DEVICE
281
+ PredicatedTileIterator &operator++() {
282
+ if (kAdvanceRank)
283
+ address_iterator_.add_tile_offset({0, 1});
284
+ else
285
+ address_iterator_.add_tile_offset({1, 0});
286
+
287
+ return *this;
288
+ }
289
+
290
+ /// Advances to the next tile in memory.
291
+ ///
292
+ /// The first time this method is called, predicates are updated, and the
293
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
294
+ /// Subsequent calls are lightweight and must only update the internal
295
+ /// pointer.
296
+ CUTLASS_HOST_DEVICE
297
+ PredicatedTileIterator operator++(int) {
298
+ PredicatedTileIterator self(*this);
299
+ operator++();
300
+ return self;
301
+ }
302
+
303
+ /// Clears the predicate set efficiently
304
+ CUTLASS_HOST_DEVICE
305
+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
306
+
307
+ /// Clears the predicate set efficiently
308
+ CUTLASS_HOST_DEVICE
309
+ void enable_mask() { address_iterator_.enable_mask(); }
310
+
311
+ /// Sets the predicate mask, overriding value stored in predicate iterator
312
+ CUTLASS_HOST_DEVICE
313
+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
314
+
315
+ /// Gets the mask
316
+ CUTLASS_HOST_DEVICE
317
+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
318
+
319
+ CUTLASS_DEVICE
320
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
321
+ load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
322
+ }
323
+
324
+ CUTLASS_DEVICE
325
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
326
+
327
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
328
+
329
+ CUTLASS_PRAGMA_UNROLL
330
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
331
+ CUTLASS_PRAGMA_UNROLL
332
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
333
+
334
+ CUTLASS_PRAGMA_UNROLL
335
+ for (int v = 0; v < kAccessesPerVector; ++v) {
336
+
337
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
338
+
339
+ address_iterator_.set_iteration_index(idx);
340
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
341
+
342
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
343
+
344
+ cutlass::arch::global_load<AccessType,
345
+ sizeof(AccessType)
346
+ >(
347
+ frag_ptr[idx], access_ptr, address_iterator_.valid());
348
+
349
+ ++address_iterator_;
350
+ }
351
+ }
352
+ }
353
+ }
354
+
355
+ /// Loads a fragment from memory
356
+ CUTLASS_DEVICE
357
+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
358
+
359
+ /// Store a fragment to memory
360
+ CUTLASS_DEVICE
361
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
362
+ store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
363
+ }
364
+
365
+ /// Store a fragment to memory
366
+ CUTLASS_DEVICE
367
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
368
+ address_iterator_.set_iteration_index(0);
369
+ AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
370
+
371
+ CUTLASS_PRAGMA_UNROLL
372
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
373
+ CUTLASS_PRAGMA_UNROLL
374
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
375
+ CUTLASS_PRAGMA_UNROLL
376
+ for (int v = 0; v < kAccessesPerVector; ++v) {
377
+
378
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
379
+
380
+ char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
381
+ AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
382
+
383
+ if (address_iterator_.valid()) {
384
+ *access_ptr = frag_ptr[idx];
385
+ }
386
+ ++address_iterator_;
387
+ }
388
+ }
389
+ }
390
+ }
391
+
392
+ /// Store a fragment to memory
393
+ CUTLASS_DEVICE
394
+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
395
+ };
396
+
397
+ ////////////////////////////////////////////////////////////////////////////////
398
+
399
+ /// Specialization of PredicatedTileIterator for column-major data.
400
+ ///
401
+ /// Satisfies: ForwardTileIteratorConcept |
402
+ /// ReadableContiguousTileIteratorConcept |
403
+ /// WriteableContiguousTileIteratorConcept |
404
+ /// MaskedTileIteratorConcept
405
+ ///
406
+ template <
407
+ typename Shape_,
408
+ typename Element_,
409
+ int AdvanceRank,
410
+ typename ThreadMap_,
411
+ int AccessSize,
412
+ bool Gather,
413
+ typename PermuteLayout
414
+ >
415
+ class PredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank,
416
+ ThreadMap_, AccessSize, Gather, PermuteLayout> {
417
+ public:
418
+
419
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
420
+ "Specialization for pitch-linear iterator may along advance along the "
421
+ "contiguous(rank=0) or strided(rank=1) dimension.");
422
+
423
+ using Shape = Shape_;
424
+ using Element = Element_;
425
+ using Layout = layout::ColumnMajor;
426
+ static int const kAdvanceRank = AdvanceRank;
427
+ using ThreadMap = ThreadMap_;
428
+
429
+ using Index = typename Layout::Index;
430
+ using LongIndex = typename Layout::LongIndex;
431
+
432
+ using TensorRef = TensorRef<Element, Layout>;
433
+ using TensorView = TensorView<Element, Layout>;
434
+ using TensorCoord = typename Layout::TensorCoord;
435
+
436
+ using Pointer = Element *;
437
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
438
+
439
+ using UnderlyingIterator = PredicatedTileIterator<
440
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
441
+ Element,
442
+ layout::PitchLinear,
443
+ (kAdvanceRank == 0 ? 0 : 1),
444
+ ThreadMap,
445
+ AccessSize,
446
+ Gather,
447
+ PermuteLayout
448
+ >;
449
+
450
+ using AccessType = typename UnderlyingIterator::AccessType;
451
+
452
+ /// Fragment object to be loaded or stored
453
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
454
+
455
+ /// Predicate vector stores mask to guard accesses
456
+ using Mask = typename UnderlyingIterator::Mask;
457
+
458
+ /// Parameters object is precomputed state and is host-constructible
459
+ class Params {
460
+ private:
461
+
462
+ friend PredicatedTileIterator;
463
+
464
+ /// Parameters object
465
+ typename UnderlyingIterator::Params params_;
466
+
467
+ public:
468
+
469
+ /// Default constructor
470
+ Params() = default;
471
+
472
+ /// Construct the Params object given a pitch-linear tensor's layout
473
+ CUTLASS_HOST_DEVICE
474
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0)))
475
+ {}
476
+
477
+ CUTLASS_HOST_DEVICE
478
+ Params(typename UnderlyingIterator::Params::Base const &base)
479
+ : params_(base) {}
480
+ };
481
+
482
+
483
+ private:
484
+
485
+ //
486
+ // Data members
487
+ //
488
+
489
+ /// Underlying pitch-linear tile iterator
490
+ UnderlyingIterator iterator_;
491
+
492
+ public:
493
+
494
+ /// Default constructor
495
+ PredicatedTileIterator() = default;
496
+
497
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
498
+ CUTLASS_HOST_DEVICE
499
+ PredicatedTileIterator(
500
+ Params const &params, ///< Precomputed parameters object
501
+ Pointer pointer, ///< Pointer to start of tensor
502
+ TensorCoord extent, ///< Extent of tensor
503
+ int thread_id, ///< ID of each participating thread
504
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
505
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
506
+ ):
507
+ iterator_(
508
+ params.params_,
509
+ pointer,
510
+ layout::PitchLinearCoord(extent.row(), extent.column()),
511
+ thread_id,
512
+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()),
513
+ indices)
514
+ { }
515
+
516
+ /// Construct a PredicatedTileIterator with zero threadblock offset
517
+ CUTLASS_HOST_DEVICE
518
+ PredicatedTileIterator(
519
+ Params const &params, ///< Precomputed parameters object
520
+ Pointer pointer, ///< Pointer to start of tensor
521
+ TensorCoord extent, ///< Extent of tensor
522
+ int thread_id ///< ID of each participating thread
523
+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
524
+
525
+ /// Adds a pointer offset in units of Element
526
+ CUTLASS_HOST_DEVICE
527
+ void add_pointer_offset(LongIndex pointer_offset) {
528
+ iterator_.add_pointer_offset(pointer_offset);
529
+ }
530
+
531
+ /// Advances to the next tile in memory.
532
+ ///
533
+ /// The first time this method is called, predicates are updated, and the iterator's
534
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
535
+ /// are lightweight and must only update the internal pointer.
536
+ CUTLASS_HOST_DEVICE
537
+ PredicatedTileIterator &operator++() {
538
+ ++iterator_;
539
+ return *this;
540
+ }
541
+
542
+ /// Advances to the next tile in memory.
543
+ ///
544
+ /// The first time this method is called, predicates are updated, and the iterator's
545
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
546
+ /// are lightweight and must only update the internal pointer.
547
+ CUTLASS_HOST_DEVICE
548
+ PredicatedTileIterator operator++(int) {
549
+ PredicatedTileIterator self(*this);
550
+ operator++();
551
+ return self;
552
+ }
553
+
554
+ /// Clears the predicate set efficiently
555
+ CUTLASS_HOST_DEVICE
556
+ void clear_mask(bool enable = true) {
557
+ iterator_.clear_mask(enable);
558
+ }
559
+
560
+ /// Clears the predicate set efficiently
561
+ CUTLASS_HOST_DEVICE
562
+ void enable_mask() {
563
+ iterator_.enable_mask();
564
+ }
565
+
566
+ /// Sets the predicate mask, overriding value stored in predicate iterator
567
+ CUTLASS_HOST_DEVICE
568
+ void set_mask(Mask const &mask) {
569
+ iterator_.set_mask(mask);
570
+ }
571
+
572
+ /// Gets the mask
573
+ CUTLASS_HOST_DEVICE
574
+ void get_mask(Mask &mask) {
575
+ iterator_.get_mask(mask);
576
+ }
577
+
578
+ /// Loads a fragment from memory
579
+ CUTLASS_DEVICE
580
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
581
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
582
+ }
583
+
584
+ /// Loads a fragment from memory
585
+ CUTLASS_DEVICE
586
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
587
+ iterator_.load_with_byte_offset(frag, byte_offset);
588
+ }
589
+
590
+ /// Loads a fragment from memory
591
+ CUTLASS_DEVICE
592
+ void load(Fragment &frag) {
593
+ load_with_pointer_offset(frag, 0);
594
+ }
595
+
596
+ /// Store a fragment to memory
597
+ CUTLASS_DEVICE
598
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
599
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
600
+ }
601
+
602
+ /// Store a fragment to memory
603
+ CUTLASS_DEVICE
604
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
605
+ iterator_.store_with_byte_offset(frag, byte_offset);
606
+ }
607
+
608
+ /// Store a fragment to memory
609
+ CUTLASS_DEVICE
610
+ void store(Fragment const &frag) {
611
+ store_with_pointer_offset(frag, 0);
612
+ }
613
+ };
614
+
615
+ ////////////////////////////////////////////////////////////////////////////////
616
+
617
+ /// Specialization of PredicatedTileIterator for row-major data.
618
+ ///
619
+ /// Satisfies: ForwardTileIteratorConcept |
620
+ /// ReadableContiguousTileIteratorConcept |
621
+ /// WriteableContiguousTileIteratorConcept |
622
+ /// MaskedTileIteratorConcept
623
+ ///
624
+ template <
625
+ typename Shape_,
626
+ typename Element_,
627
+ int AdvanceRank,
628
+ typename ThreadMap_,
629
+ int AccessSize,
630
+ bool Gather,
631
+ typename PermuteLayout
632
+ >
633
+ class PredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank,
634
+ ThreadMap_, AccessSize, Gather, PermuteLayout> {
635
+ public:
636
+
637
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
638
+ "Specialization for pitch-linear iterator may along advance along the "
639
+ "contiguous(rank=0) or strided(rank=1) dimension.");
640
+
641
+ using Shape = Shape_;
642
+ using Element = Element_;
643
+ using Layout = layout::RowMajor;
644
+ static int const kAdvanceRank = AdvanceRank;
645
+ using ThreadMap = ThreadMap_;
646
+
647
+ using Index = typename Layout::Index;
648
+ using LongIndex = typename Layout::LongIndex;
649
+
650
+ using TensorRef = TensorRef<Element, Layout>;
651
+ using TensorView = TensorView<Element, Layout>;
652
+ using TensorCoord = typename Layout::TensorCoord;
653
+
654
+ using Pointer = Element *;
655
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
656
+
657
+ using UnderlyingIterator = PredicatedTileIterator<
658
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
659
+ Element,
660
+ layout::PitchLinear,
661
+ (kAdvanceRank == 0 ? 1 : 0),
662
+ ThreadMap,
663
+ AccessSize,
664
+ Gather,
665
+ PermuteLayout
666
+ >;
667
+
668
+ using AccessType = typename UnderlyingIterator::AccessType;
669
+
670
+ /// Fragment object to be loaded or stored
671
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
672
+
673
+ /// Predicate vector stores mask to guard accesses
674
+ using Mask = typename UnderlyingIterator::Mask;
675
+
676
+ /// Parameters object is precomputed state and is host-constructible
677
+ class Params {
678
+ private:
679
+
680
+ friend PredicatedTileIterator;
681
+
682
+ /// Parameters object
683
+ typename UnderlyingIterator::Params params_;
684
+
685
+ public:
686
+
687
+ /// Default constructor
688
+ Params() = default;
689
+
690
+ /// Construct the Params object given a pitch-linear tensor's layout
691
+ CUTLASS_HOST_DEVICE
692
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}
693
+
694
+ CUTLASS_HOST_DEVICE
695
+ Params(typename UnderlyingIterator::Params::Base const &base)
696
+ : params_(base) {}
697
+
698
+ };
699
+
700
+ private:
701
+
702
+ //
703
+ // Data members
704
+ //
705
+
706
+ /// Underlying pitch-linear tile iterator
707
+ UnderlyingIterator iterator_;
708
+
709
+ public:
710
+
711
+ /// Default constructor
712
+ PredicatedTileIterator() = default;
713
+
714
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
715
+ CUTLASS_HOST_DEVICE
716
+ PredicatedTileIterator(
717
+ Params const &params, ///< Precomputed parameters object
718
+ Pointer pointer, ///< Pointer to start of tensor
719
+ TensorCoord extent, ///< Extent of tensor
720
+ int thread_id, ///< ID of each participating thread
721
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
722
+ int const *indices = nullptr ///< Gather indices
723
+ ):
724
+ iterator_(
725
+ params.params_,
726
+ pointer,
727
+ layout::PitchLinearCoord(extent.column(), extent.row()),
728
+ thread_id,
729
+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()),
730
+ indices
731
+ ) { }
732
+
733
+ /// Construct a PredicatedTileIterator with zero threadblock offset
734
+ CUTLASS_HOST_DEVICE
735
+ PredicatedTileIterator(
736
+ Params const &params, ///< Precomputed parameters object
737
+ Pointer pointer, ///< Pointer to start of tensor
738
+ TensorCoord extent, ///< Extent of tensor
739
+ int thread_id ///< ID of each participating thread
740
+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
741
+
742
+ /// Adds a pointer offset in units of Element
743
+ CUTLASS_HOST_DEVICE
744
+ void add_pointer_offset(LongIndex pointer_offset) {
745
+ iterator_.add_pointer_offset(pointer_offset);
746
+ }
747
+
748
+ /// Advances to the next tile in memory.
749
+ ///
750
+ /// The first time this method is called, predicates are updated, and the iterator's
751
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
752
+ /// are lightweight and must only update the internal pointer.
753
+ CUTLASS_HOST_DEVICE
754
+ PredicatedTileIterator &operator++() {
755
+ ++iterator_;
756
+ return *this;
757
+ }
758
+
759
+ /// Advances to the next tile in memory.
760
+ ///
761
+ /// The first time this method is called, predicates are updated, and the iterator's
762
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
763
+ /// are lightweight and must only update the internal pointer.
764
+ CUTLASS_HOST_DEVICE
765
+ PredicatedTileIterator operator++(int) {
766
+ PredicatedTileIterator self(*this);
767
+ operator++();
768
+ return self;
769
+ }
770
+
771
+ /// Clears the predicate set efficiently
772
+ CUTLASS_HOST_DEVICE
773
+ void clear_mask(bool enable = true) {
774
+ iterator_.clear_mask(enable);
775
+ }
776
+
777
+ /// Clears the predicate set efficiently
778
+ CUTLASS_HOST_DEVICE
779
+ void enable_mask() {
780
+ iterator_.enable_mask();
781
+ }
782
+
783
+ /// Sets the predicate mask, overriding value stored in predicate iterator
784
+ CUTLASS_HOST_DEVICE
785
+ void set_mask(Mask const &mask) {
786
+ iterator_.set_mask(mask);
787
+ }
788
+
789
+ /// Gets the mask
790
+ CUTLASS_HOST_DEVICE
791
+ void get_mask(Mask &mask) {
792
+ iterator_.get_mask(mask);
793
+ }
794
+
795
+ /// Loads a fragment from memory
796
+ CUTLASS_DEVICE
797
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
798
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
799
+ }
800
+
801
+ /// Loads a fragment from memory
802
+ CUTLASS_DEVICE
803
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
804
+ iterator_.load_with_byte_offset(frag, byte_offset);
805
+ }
806
+
807
+ /// Loads a fragment from memory
808
+ CUTLASS_DEVICE
809
+ void load(Fragment &frag) {
810
+ load_with_pointer_offset(frag, 0);
811
+ }
812
+
813
+ /// Store a fragment to memory
814
+ CUTLASS_DEVICE
815
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
816
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
817
+ }
818
+
819
+ /// Store a fragment to memory
820
+ CUTLASS_DEVICE
821
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
822
+ iterator_.store_with_byte_offset(frag, byte_offset);
823
+ }
824
+
825
+ /// Store a fragment to memory
826
+ CUTLASS_DEVICE
827
+ void store(Fragment const &frag) {
828
+ store_with_pointer_offset(frag, 0);
829
+ }
830
+ };
831
+
832
+ ////////////////////////////////////////////////////////////////////////////////
833
+
834
+ /// Specialization of PredicatedTileIterator for affine rank-2 data.
835
+ ///
836
+ /// Satisfies: ForwardTileIteratorConcept |
837
+ /// ReadableContiguousTileIteratorConcept |
838
+ /// WriteableContiguousTileIteratorConcept |
839
+ /// MaskedTileIteratorConcept
840
+ ///
841
+ template <typename Shape_, typename Element_, int AdvanceRank,
842
+ typename ThreadMap_, int AccessSize>
843
+ class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRank,
844
+ ThreadMap_, AccessSize, false> {
845
+ public:
846
+ static_assert(
847
+ AdvanceRank == 0 || AdvanceRank == 1,
848
+ "Specialization for pitch-linear iterator may advance along the "
849
+ "contiguous(rank=0) or strided(rank=1) dimension.");
850
+
851
+ using Shape = Shape_;
852
+ using Element = Element_;
853
+ using Layout = layout::AffineRankN<2>;
854
+ static int const kAdvanceRank = AdvanceRank;
855
+ using ThreadMap = ThreadMap_;
856
+
857
+ using Index = typename Layout::Index;
858
+ using LongIndex = typename Layout::LongIndex;
859
+
860
+ using TensorRef = TensorRef<Element, Layout>;
861
+ using TensorView = TensorView<Element, Layout>;
862
+ using TensorCoord = typename Layout::TensorCoord;
863
+
864
+ using Pointer = Element *;
865
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
866
+
867
+ /// Type used for internal memory accesses
868
+ using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
869
+
870
+ /// Underlying iterator to compute the addresses
871
+ using TileAccessIterator =
872
+ PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
873
+ ThreadMap, AccessType>;
874
+
875
+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
876
+
877
+ /// Fragment object to be loaded or stored
878
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
879
+ ThreadMap::kElementsPerAccess>;
880
+
881
+ /// Predicate vector stores mask to guard accesses
882
+ using Mask = typename TileAccessIterator::Mask;
883
+
884
+ /// Parameters object is precomputed state and is host-constructible
885
+ class Params {
886
+ public:
887
+
888
+ friend PredicatedTileIterator;
889
+
890
+ private:
891
+ /// Parameters object
892
+ typename TileAccessIterator::Params params_;
893
+
894
+ public:
895
+ /// Construct the Params object given a pitch-linear tensor's layout
896
+ CUTLASS_HOST_DEVICE
897
+ Params(Layout const &layout) : params_(layout) {}
898
+
899
+ /// Default constructor
900
+ Params() = default;
901
+ };
902
+
903
+ private:
904
+ /// Internal pointer type permits fast address arithmetic
905
+ using BytePointer = char *;
906
+
907
+ private:
908
+ //
909
+ // Data members
910
+ //
911
+
912
+ /// Data member to the tile access iterator
913
+ TileAccessIterator address_iterator_;
914
+
915
+ public:
916
+
917
+ /// Default constructor
918
+ PredicatedTileIterator() = default;
919
+
920
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
921
+ /// and thread ID
922
+ CUTLASS_HOST_DEVICE
923
+ PredicatedTileIterator(
924
+ /// Precomputed parameters object
925
+ Params const &params,
926
+ /// Pointer to start of tensor
927
+ Pointer pointer,
928
+ /// Extent of tensor
929
+ TensorCoord extent,
930
+ /// ID of each participating thread
931
+ int thread_id,
932
+ /// Initial offset of threadblock
933
+ TensorCoord const &threadblock_offset,
934
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
935
+ )
936
+ : address_iterator_(params.params_, pointer, extent, thread_id,
937
+ threadblock_offset) {}
938
+
939
+ /// Construct a PredicatedTileIterator with zero threadblock offset
940
+ CUTLASS_HOST_DEVICE
941
+ PredicatedTileIterator(
942
+ Params const &params, ///< Precomputed parameters object
943
+ Pointer pointer, ///< Pointer to start of tensor
944
+ TensorCoord extent, ///< Extent of tensor
945
+ int thread_id ///< ID of each participating thread
946
+ )
947
+ : PredicatedTileIterator(params, pointer, extent, thread_id,
948
+ make_Coord(0, 0)) {}
949
+
950
+ /// Adds a pointer offset in units of Element
951
+ CUTLASS_HOST_DEVICE
952
+ void add_pointer_offset(LongIndex pointer_offset) {
953
+ address_iterator_.add_pointer_offset(pointer_offset);
954
+ }
955
+
956
+ /// Advances to the next tile in memory.
957
+ ///
958
+ /// The first time this method is called, predicates are updated, and the
959
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
960
+ /// Subsequent calls are lightweight and must only update the internal
961
+ /// pointer.
962
+ CUTLASS_HOST_DEVICE
963
+ PredicatedTileIterator &operator++() {
964
+ if (kAdvanceRank)
965
+ address_iterator_.add_tile_offset(make_Coord(0, 1));
966
+ else
967
+ address_iterator_.add_tile_offset(make_Coord(1, 0));
968
+
969
+ return *this;
970
+ }
971
+
972
+ /// Advances to the next tile in memory.
973
+ ///
974
+ /// The first time this method is called, predicates are updated, and the
975
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
976
+ /// Subsequent calls are lightweight and must only update the internal
977
+ /// pointer.
978
+ CUTLASS_HOST_DEVICE
979
+ PredicatedTileIterator operator++(int) {
980
+ PredicatedTileIterator self(*this);
981
+ operator++();
982
+ return self;
983
+ }
984
+
985
+ /// Clears the predicate set efficiently
986
+ CUTLASS_HOST_DEVICE
987
+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
988
+
989
+ /// Clears the predicate set efficiently
990
+ CUTLASS_HOST_DEVICE
991
+ void enable_mask() { address_iterator_.enable_mask(); }
992
+
993
+ /// Sets the predicate mask, overriding value stored in predicate iterator
994
+ CUTLASS_HOST_DEVICE
995
+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
996
+
997
+ /// Gets the mask
998
+ CUTLASS_HOST_DEVICE
999
+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
1000
+
1001
+ CUTLASS_DEVICE
1002
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1003
+ load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
1004
+ }
1005
+
1006
+ CUTLASS_DEVICE
1007
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
1008
+
1009
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
1010
+
1011
+ CUTLASS_PRAGMA_UNROLL
1012
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
1013
+ CUTLASS_PRAGMA_UNROLL
1014
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
1015
+
1016
+ CUTLASS_PRAGMA_UNROLL
1017
+ for (int v = 0; v < kAccessesPerVector; ++v) {
1018
+
1019
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
1020
+
1021
+ address_iterator_.set_iteration_index(idx);
1022
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
1023
+
1024
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
1025
+
1026
+ cutlass::arch::global_load<AccessType,
1027
+ sizeof(AccessType)
1028
+ >(
1029
+ frag_ptr[idx], access_ptr, address_iterator_.valid());
1030
+
1031
+ ++address_iterator_;
1032
+ }
1033
+ }
1034
+ }
1035
+ }
1036
+
1037
+ /// Loads a fragment from memory
1038
+ CUTLASS_DEVICE
1039
+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
1040
+
1041
+ /// Store a fragment to memory
1042
+ CUTLASS_DEVICE
1043
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1044
+ store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
1045
+ }
1046
+
1047
+ /// Store a fragment to memory
1048
+ CUTLASS_DEVICE
1049
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
1050
+ address_iterator_.set_iteration_index(0);
1051
+ AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
1052
+
1053
+ CUTLASS_PRAGMA_UNROLL
1054
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
1055
+ CUTLASS_PRAGMA_UNROLL
1056
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
1057
+ CUTLASS_PRAGMA_UNROLL
1058
+ for (int v = 0; v < kAccessesPerVector; ++v) {
1059
+
1060
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
1061
+
1062
+ char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
1063
+ AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
1064
+
1065
+ if (address_iterator_.valid()) {
1066
+ *access_ptr = frag_ptr[idx];
1067
+ }
1068
+ ++address_iterator_;
1069
+ }
1070
+ }
1071
+ }
1072
+ }
1073
+
1074
+ /// Store a fragment to memory
1075
+ CUTLASS_DEVICE
1076
+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
1077
+ };
1078
+
1079
+ ////////////////////////////////////////////////////////////////////////////////
1080
+
1081
+ /// Specialization of PredicatedTileIterator for affine rank 2 column-major data.
1082
+ ///
1083
+ /// Satisfies: ForwardTileIteratorConcept |
1084
+ /// ReadableContiguousTileIteratorConcept |
1085
+ /// WriteableContiguousTileIteratorConcept |
1086
+ /// MaskedTileIteratorConcept
1087
+ ///
1088
+ template <
1089
+ typename Shape_,
1090
+ typename Element_,
1091
+ int AdvanceRank,
1092
+ typename ThreadMap_,
1093
+ int AccessSize
1094
+ >
1095
+ class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2ColumnMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
1096
+ public:
1097
+
1098
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
1099
+ "Specialization for pitch-linear iterator may along advance along the "
1100
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1101
+
1102
+ using Shape = Shape_;
1103
+ using Element = Element_;
1104
+ using Layout = layout::AffineRank2ColumnMajor;
1105
+ static int const kAdvanceRank = AdvanceRank;
1106
+ using ThreadMap = ThreadMap_;
1107
+
1108
+ using Index = typename Layout::Index;
1109
+ using LongIndex = typename Layout::LongIndex;
1110
+
1111
+ using TensorRef = TensorRef<Element, Layout>;
1112
+ using TensorView = TensorView<Element, Layout>;
1113
+ using TensorCoord = typename Layout::TensorCoord;
1114
+
1115
+ using Pointer = Element *;
1116
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1117
+
1118
+ // Map to the underlying AffineRankN<2> layout
1119
+ using UnderlyingIterator = PredicatedTileIterator<
1120
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
1121
+ Element,
1122
+ layout::AffineRankN<2>,
1123
+ (kAdvanceRank == 0 ? 0 : 1),
1124
+ ThreadMap,
1125
+ AccessSize
1126
+ >;
1127
+
1128
+ using AccessType = typename UnderlyingIterator::AccessType;
1129
+
1130
+ /// Fragment object to be loaded or stored
1131
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1132
+
1133
+ /// Predicate vector stores mask to guard accesses
1134
+ using Mask = typename UnderlyingIterator::Mask;
1135
+
1136
+ /// Parameters object is precomputed state and is host-constructible
1137
+ class Params {
1138
+ private:
1139
+
1140
+ friend PredicatedTileIterator;
1141
+
1142
+ /// Parameters object
1143
+ typename UnderlyingIterator::Params params_;
1144
+
1145
+ public:
1146
+
1147
+ /// Default constructor
1148
+ Params() = default;
1149
+
1150
+ /// Construct the Params object given an AffineRankN<2> tensor's layout
1151
+ CUTLASS_HOST_DEVICE
1152
+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1)))
1153
+ {}
1154
+ };
1155
+
1156
+ private:
1157
+
1158
+ //
1159
+ // Data members
1160
+ //
1161
+
1162
+ /// Underlying AffineRankN<2> tile iterator
1163
+ UnderlyingIterator iterator_;
1164
+
1165
+ public:
1166
+
1167
+ /// Default constructor
1168
+ PredicatedTileIterator() = default;
1169
+
1170
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
1171
+ CUTLASS_HOST_DEVICE
1172
+ PredicatedTileIterator(
1173
+ Params const &params, ///< Precomputed parameters object
1174
+ Pointer pointer, ///< Pointer to start of tensor
1175
+ TensorCoord extent, ///< Extent of tensor
1176
+ int thread_id, ///< ID of each participating thread
1177
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
1178
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1179
+ ):
1180
+ iterator_(
1181
+ params.params_,
1182
+ pointer,
1183
+ layout::PitchLinearCoord(extent.row(), extent.column()),
1184
+ thread_id,
1185
+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
1186
+ ) { }
1187
+
1188
+ /// Construct a PredicatedTileIterator with zero threadblock offset
1189
+ CUTLASS_HOST_DEVICE
1190
+ PredicatedTileIterator(
1191
+ Params const &params, ///< Precomputed parameters object
1192
+ Pointer pointer, ///< Pointer to start of tensor
1193
+ TensorCoord extent, ///< Extent of tensor
1194
+ int thread_id ///< ID of each participating thread
1195
+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
1196
+
1197
+ /// Adds a pointer offset in units of Element
1198
+ CUTLASS_HOST_DEVICE
1199
+ void add_pointer_offset(LongIndex pointer_offset) {
1200
+ iterator_.add_pointer_offset(pointer_offset);
1201
+ }
1202
+
1203
+ /// Advances to the next tile in memory.
1204
+ ///
1205
+ /// The first time this method is called, predicates are updated, and the iterator's
1206
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
1207
+ /// are lightweight and must only update the internal pointer.
1208
+ CUTLASS_HOST_DEVICE
1209
+ PredicatedTileIterator &operator++() {
1210
+ ++iterator_;
1211
+ return *this;
1212
+ }
1213
+
1214
+ /// Advances to the next tile in memory.
1215
+ ///
1216
+ /// The first time this method is called, predicates are updated, and the iterator's
1217
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
1218
+ /// are lightweight and must only update the internal pointer.
1219
+ CUTLASS_HOST_DEVICE
1220
+ PredicatedTileIterator operator++(int) {
1221
+ PredicatedTileIterator self(*this);
1222
+ operator++();
1223
+ return self;
1224
+ }
1225
+
1226
+ /// Clears the predicate set efficiently
1227
+ CUTLASS_HOST_DEVICE
1228
+ void clear_mask(bool enable = true) {
1229
+ iterator_.clear_mask(enable);
1230
+ }
1231
+
1232
+ /// Clears the predicate set efficiently
1233
+ CUTLASS_HOST_DEVICE
1234
+ void enable_mask() {
1235
+ iterator_.enable_mask();
1236
+ }
1237
+
1238
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1239
+ CUTLASS_HOST_DEVICE
1240
+ void set_mask(Mask const &mask) {
1241
+ iterator_.set_mask(mask);
1242
+ }
1243
+
1244
+ /// Gets the mask
1245
+ CUTLASS_HOST_DEVICE
1246
+ void get_mask(Mask &mask) {
1247
+ iterator_.get_mask(mask);
1248
+ }
1249
+
1250
+ /// Loads a fragment from memory
1251
+ CUTLASS_DEVICE
1252
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1253
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1254
+ }
1255
+
1256
+ /// Loads a fragment from memory
1257
+ CUTLASS_DEVICE
1258
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
1259
+ iterator_.load_with_byte_offset(frag, byte_offset);
1260
+ }
1261
+
1262
+ /// Loads a fragment from memory
1263
+ CUTLASS_DEVICE
1264
+ void load(Fragment &frag) {
1265
+ load_with_pointer_offset(frag, 0);
1266
+ }
1267
+
1268
+ /// Store a fragment to memory
1269
+ CUTLASS_DEVICE
1270
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1271
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1272
+ }
1273
+
1274
+ /// Store a fragment to memory
1275
+ CUTLASS_DEVICE
1276
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
1277
+ iterator_.store_with_byte_offset(frag, byte_offset);
1278
+ }
1279
+
1280
+ /// Store a fragment to memory
1281
+ CUTLASS_DEVICE
1282
+ void store(Fragment const &frag) {
1283
+ store_with_pointer_offset(frag, 0);
1284
+ }
1285
+ };
1286
+
1287
+ ////////////////////////////////////////////////////////////////////////////////
1288
+
1289
+ /// Specialization of PredicatedTileIterator for affine rank 2 row-major data.
1290
+ ///
1291
+ /// Satisfies: ForwardTileIteratorConcept |
1292
+ /// ReadableContiguousTileIteratorConcept |
1293
+ /// WriteableContiguousTileIteratorConcept |
1294
+ /// MaskedTileIteratorConcept
1295
+ ///
1296
+ template <
1297
+ typename Shape_,
1298
+ typename Element_,
1299
+ int AdvanceRank,
1300
+ typename ThreadMap_,
1301
+ int AccessSize
1302
+ >
1303
+ class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2RowMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
1304
+ public:
1305
+
1306
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
1307
+ "Specialization for pitch-linear iterator may along advance along the "
1308
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1309
+
1310
+ using Shape = Shape_;
1311
+ using Element = Element_;
1312
+ using Layout = layout::AffineRank2RowMajor;
1313
+ static int const kAdvanceRank = AdvanceRank;
1314
+ using ThreadMap = ThreadMap_;
1315
+
1316
+ using Index = typename Layout::Index;
1317
+ using LongIndex = typename Layout::LongIndex;
1318
+
1319
+ using TensorRef = TensorRef<Element, Layout>;
1320
+ using TensorView = TensorView<Element, Layout>;
1321
+ using TensorCoord = typename Layout::TensorCoord;
1322
+
1323
+ using Pointer = Element *;
1324
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1325
+
1326
+ // Map to the underlying AffineRankN<2> layout
1327
+ using UnderlyingIterator = PredicatedTileIterator<
1328
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
1329
+ Element,
1330
+ layout::AffineRankN<2>,
1331
+ (kAdvanceRank == 0 ? 1 : 0),
1332
+ ThreadMap,
1333
+ AccessSize
1334
+ >;
1335
+
1336
+ using AccessType = typename UnderlyingIterator::AccessType;
1337
+
1338
+ /// Fragment object to be loaded or stored
1339
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
1340
+
1341
+ /// Predicate vector stores mask to guard accesses
1342
+ using Mask = typename UnderlyingIterator::Mask;
1343
+
1344
+ /// Parameters object is precomputed state and is host-constructible
1345
+ class Params {
1346
+ private:
1347
+
1348
+ friend PredicatedTileIterator;
1349
+
1350
+ /// Parameters object
1351
+ typename UnderlyingIterator::Params params_;
1352
+
1353
+ public:
1354
+
1355
+ /// Default constructor
1356
+ Params() = default;
1357
+
1358
+ /// Construct the Params object given an AffineRankN<2> tensor's layout
1359
+ CUTLASS_HOST_DEVICE
1360
+ Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
1361
+ };
1362
+
1363
+
1364
+ private:
1365
+
1366
+ //
1367
+ // Data members
1368
+ //
1369
+
1370
+ /// Underlying AffineRankN<2> tile iterator
1371
+ UnderlyingIterator iterator_;
1372
+
1373
+ public:
1374
+
1375
+ /// Default constructor
1376
+ PredicatedTileIterator() = default;
1377
+
1378
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
1379
+ CUTLASS_HOST_DEVICE
1380
+ PredicatedTileIterator(
1381
+ Params const &params, ///< Precomputed parameters object
1382
+ Pointer pointer, ///< Pointer to start of tensor
1383
+ TensorCoord extent, ///< Extent of tensor
1384
+ int thread_id, ///< ID of each participating thread
1385
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
1386
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1387
+ ):
1388
+ iterator_(
1389
+ params.params_,
1390
+ pointer,
1391
+ layout::PitchLinearCoord(extent.column(), extent.row()),
1392
+ thread_id,
1393
+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
1394
+ ) { }
1395
+
1396
+ /// Construct a PredicatedTileIterator with zero threadblock offset
1397
+ CUTLASS_HOST_DEVICE
1398
+ PredicatedTileIterator(
1399
+ Params const &params, ///< Precomputed parameters object
1400
+ Pointer pointer, ///< Pointer to start of tensor
1401
+ TensorCoord extent, ///< Extent of tensor
1402
+ int thread_id ///< ID of each participating thread
1403
+ ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
1404
+
1405
+ /// Adds a pointer offset in units of Element
1406
+ CUTLASS_HOST_DEVICE
1407
+ void add_pointer_offset(LongIndex pointer_offset) {
1408
+ iterator_.add_pointer_offset(pointer_offset);
1409
+ }
1410
+
1411
+ /// Advances to the next tile in memory.
1412
+ ///
1413
+ /// The first time this method is called, predicates are updated, and the iterator's
1414
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
1415
+ /// are lightweight and must only update the internal pointer.
1416
+ CUTLASS_HOST_DEVICE
1417
+ PredicatedTileIterator &operator++() {
1418
+ ++iterator_;
1419
+ return *this;
1420
+ }
1421
+
1422
+ /// Advances to the next tile in memory.
1423
+ ///
1424
+ /// The first time this method is called, predicates are updated, and the iterator's
1425
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
1426
+ /// are lightweight and must only update the internal pointer.
1427
+ CUTLASS_HOST_DEVICE
1428
+ PredicatedTileIterator operator++(int) {
1429
+ PredicatedTileIterator self(*this);
1430
+ operator++();
1431
+ return self;
1432
+ }
1433
+
1434
+ /// Clears the predicate set efficiently
1435
+ CUTLASS_HOST_DEVICE
1436
+ void clear_mask(bool enable = true) {
1437
+ iterator_.clear_mask(enable);
1438
+ }
1439
+
1440
+ /// Clears the predicate set efficiently
1441
+ CUTLASS_HOST_DEVICE
1442
+ void enable_mask() {
1443
+ iterator_.enable_mask();
1444
+ }
1445
+
1446
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1447
+ CUTLASS_HOST_DEVICE
1448
+ void set_mask(Mask const &mask) {
1449
+ iterator_.set_mask(mask);
1450
+ }
1451
+
1452
+ /// Gets the mask
1453
+ CUTLASS_HOST_DEVICE
1454
+ void get_mask(Mask &mask) {
1455
+ iterator_.get_mask(mask);
1456
+ }
1457
+
1458
+ /// Loads a fragment from memory
1459
+ CUTLASS_DEVICE
1460
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1461
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1462
+ }
1463
+
1464
+ /// Loads a fragment from memory
1465
+ CUTLASS_DEVICE
1466
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
1467
+ iterator_.load_with_byte_offset(frag, byte_offset);
1468
+ }
1469
+
1470
+ /// Loads a fragment from memory
1471
+ CUTLASS_DEVICE
1472
+ void load(Fragment &frag) {
1473
+ load_with_pointer_offset(frag, 0);
1474
+ }
1475
+
1476
+ /// Store a fragment to memory
1477
+ CUTLASS_DEVICE
1478
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1479
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1480
+ }
1481
+
1482
+ /// Store a fragment to memory
1483
+ CUTLASS_DEVICE
1484
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
1485
+ iterator_.store_with_byte_offset(frag, byte_offset);
1486
+ }
1487
+
1488
+ /// Store a fragment to memory
1489
+ CUTLASS_DEVICE
1490
+ void store(Fragment const &frag) {
1491
+ store_with_pointer_offset(frag, 0);
1492
+ }
1493
+ };
1494
+
1495
+ ////////////////////////////////////////////////////////////////////////////////
1496
+
1497
+ /// Specialization of PredicatedTileIterator for interleaved data. It is mapped
1498
+ /// to the congruous layout.
1499
+ ///
1500
+ /// Satisfies: ForwardTileIteratorConcept |
1501
+ /// ReadableContiguousTileIteratorConcept |
1502
+ /// WriteableContiguousTileIteratorConcept |
1503
+ /// MaskedTileIteratorConcept
1504
+ ///
1505
+
1506
+ template <typename Shape_, typename Element_, int AdvanceRank,
1507
+ typename ThreadMap_, int AccessSize, int InterleavedK>
1508
+ class PredicatedTileIterator<Shape_, Element_,
1509
+ layout::ColumnMajorInterleaved<InterleavedK>,
1510
+ AdvanceRank, ThreadMap_, AccessSize, false> {
1511
+ public:
1512
+ static_assert(
1513
+ AdvanceRank == 0 || AdvanceRank == 1,
1514
+ "Specialization for pitch-linear iterator may along advance along the "
1515
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1516
+
1517
+ using Shape = Shape_;
1518
+ using Element = Element_;
1519
+ static int const kInterleavedK = InterleavedK;
1520
+ using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
1521
+ static int const kAdvanceRank = AdvanceRank;
1522
+ using ThreadMap = ThreadMap_;
1523
+
1524
+ using Index = typename Layout::Index;
1525
+ using LongIndex = typename Layout::LongIndex;
1526
+
1527
+ using TensorRef = TensorRef<Element, Layout>;
1528
+ using TensorView = TensorView<Element, Layout>;
1529
+ using TensorCoord = typename Layout::TensorCoord;
1530
+
1531
+ using Pointer = Element *;
1532
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1533
+
1534
+ using UnderlyingIterator = PredicatedTileIterator<
1535
+ layout::PitchLinearShape<Shape::kRow * kInterleavedK,
1536
+ Shape::kColumn / kInterleavedK>,
1537
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
1538
+
1539
+
1540
+ using AccessType = typename UnderlyingIterator::AccessType;
1541
+
1542
+ /// Fragment object to be loaded or stored
1543
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
1544
+ ThreadMap::kElementsPerAccess>;
1545
+
1546
+ /// Predicate vector stores mask to guard accesses
1547
+ using Mask = typename UnderlyingIterator::Mask;
1548
+
1549
+ /// Parameters object is precomputed state and is host-constructible
1550
+ class Params {
1551
+ private:
1552
+ friend PredicatedTileIterator;
1553
+
1554
+ /// Parameters object
1555
+ typename UnderlyingIterator::Params params_;
1556
+
1557
+ public:
1558
+
1559
+ /// Default constructor
1560
+ Params() = default;
1561
+
1562
+ /// Construct the Params object given a pitch-linear tensor's layout
1563
+ CUTLASS_HOST_DEVICE
1564
+ Params(Layout const &layout)
1565
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1566
+
1567
+ CUTLASS_HOST_DEVICE
1568
+ Params(typename UnderlyingIterator::Params::Base const &base)
1569
+ : params_(base) {}
1570
+
1571
+ };
1572
+
1573
+ private:
1574
+ //
1575
+ // Data members
1576
+ //
1577
+
1578
+ /// Underlying pitch-linear tile iterator
1579
+ UnderlyingIterator iterator_;
1580
+
1581
+ public:
1582
+
1583
+ /// Default constructor
1584
+ PredicatedTileIterator() = default;
1585
+
1586
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1587
+ /// and thread ID
1588
+ CUTLASS_HOST_DEVICE
1589
+ PredicatedTileIterator(
1590
+ /// Precomputed parameters object
1591
+ Params const &params,
1592
+ /// Pointer to start of tensor
1593
+ Pointer pointer,
1594
+ /// Extent of tensor
1595
+ TensorCoord extent,
1596
+ /// ID of each participating thread
1597
+ int thread_id,
1598
+ /// Initial offset of threadblock
1599
+ TensorCoord const &threadblock_offset,
1600
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1601
+ )
1602
+ : iterator_(params.params_, pointer,
1603
+ layout::PitchLinearCoord(extent.row() * kInterleavedK,
1604
+ extent.column() / kInterleavedK),
1605
+ thread_id,
1606
+ layout::PitchLinearCoord(
1607
+ threadblock_offset.row() * kInterleavedK,
1608
+ threadblock_offset.column() / kInterleavedK)) {}
1609
+
1610
+ /// Construct a PredicatedTileIterator with zero threadblock offset
1611
+ CUTLASS_HOST_DEVICE
1612
+ PredicatedTileIterator(
1613
+ Params const &params, ///< Precomputed parameters object
1614
+ Pointer pointer, ///< Pointer to start of tensor
1615
+ TensorCoord extent, ///< Extent of tensor
1616
+ int thread_id ///< ID of each participating thread
1617
+ )
1618
+ : PredicatedTileIterator(params, pointer, extent, thread_id,
1619
+ make_Coord(0, 0)) {}
1620
+
1621
+ /// Adds a pointer offset in units of Element
1622
+ CUTLASS_HOST_DEVICE
1623
+ void add_pointer_offset(LongIndex pointer_offset) {
1624
+ iterator_.add_pointer_offset(pointer_offset);
1625
+ }
1626
+
1627
+ /// Advances to the next tile in memory.
1628
+ ///
1629
+ /// The first time this method is called, predicates are updated, and the
1630
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1631
+ /// Subsequent calls are lightweight and must only update the internal
1632
+ /// pointer.
1633
+ CUTLASS_HOST_DEVICE
1634
+ PredicatedTileIterator &operator++() {
1635
+ ++iterator_;
1636
+ return *this;
1637
+ }
1638
+
1639
+ /// Advances to the next tile in memory.
1640
+ ///
1641
+ /// The first time this method is called, predicates are updated, and the
1642
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1643
+ /// Subsequent calls are lightweight and must only update the internal
1644
+ /// pointer.
1645
+ CUTLASS_HOST_DEVICE
1646
+ PredicatedTileIterator operator++(int) {
1647
+ PredicatedTileIterator self(*this);
1648
+ operator++();
1649
+ return self;
1650
+ }
1651
+
1652
+ /// Clears the predicate set efficiently
1653
+ CUTLASS_HOST_DEVICE
1654
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1655
+
1656
+ /// Clears the predicate set efficiently
1657
+ CUTLASS_HOST_DEVICE
1658
+ void enable_mask() { iterator_.enable_mask(); }
1659
+
1660
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1661
+ CUTLASS_HOST_DEVICE
1662
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1663
+
1664
+ /// Gets the mask
1665
+ CUTLASS_HOST_DEVICE
1666
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1667
+
1668
+ /// Loads a fragment from memory
1669
+ CUTLASS_DEVICE
1670
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1671
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1672
+ }
1673
+
1674
+ /// Loads a fragment from memory
1675
+ CUTLASS_DEVICE
1676
+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
1677
+
1678
+ /// Store a fragment to memory
1679
+ CUTLASS_DEVICE
1680
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1681
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1682
+ }
1683
+
1684
+ /// Store a fragment to memory
1685
+ CUTLASS_DEVICE
1686
+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
1687
+ };
1688
+
1689
+ ////////////////////////////////////////////////////////////////////////////////
1690
+
1691
+ /// Specialization of PredicatedTileIterator for interleaved-32 data. It is
1692
+ /// mapped to the congruous layout.
1693
+ ///
1694
+ /// Satisfies: ForwardTileIteratorConcept |
1695
+ /// ReadableContiguousTileIteratorConcept |
1696
+ /// WriteableContiguousTileIteratorConcept |
1697
+ /// MaskedTileIteratorConcept
1698
+ ///
1699
+ template <typename Shape_, typename Element_, int AdvanceRank,
1700
+ typename ThreadMap_, int AccessSize, int InterleavedK>
1701
+ class PredicatedTileIterator<Shape_, Element_,
1702
+ layout::RowMajorInterleaved<InterleavedK>,
1703
+ AdvanceRank, ThreadMap_, AccessSize, false> {
1704
+ public:
1705
+ static_assert(
1706
+ AdvanceRank == 0 || AdvanceRank == 1,
1707
+ "Specialization for pitch-linear iterator may along advance along the "
1708
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1709
+
1710
+ using Shape = Shape_;
1711
+ using Element = Element_;
1712
+ static int const kInterleavedK = InterleavedK;
1713
+ using Layout = layout::RowMajorInterleaved<kInterleavedK>;
1714
+ static int const kAdvanceRank = AdvanceRank;
1715
+ using ThreadMap = ThreadMap_;
1716
+
1717
+ using Index = typename Layout::Index;
1718
+ using LongIndex = typename Layout::LongIndex;
1719
+
1720
+ using TensorRef = TensorRef<Element, Layout>;
1721
+ using TensorView = TensorView<Element, Layout>;
1722
+ using TensorCoord = typename Layout::TensorCoord;
1723
+
1724
+ using Pointer = Element *;
1725
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
1726
+
1727
+ using UnderlyingIterator = PredicatedTileIterator<
1728
+ layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1729
+ Shape::kRow / kInterleavedK>,
1730
+ Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
1731
+
1732
+
1733
+ using AccessType = typename UnderlyingIterator::AccessType;
1734
+
1735
+ /// Fragment object to be loaded or stored
1736
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
1737
+ ThreadMap::kElementsPerAccess>;
1738
+
1739
+ /// Predicate vector stores mask to guard accesses
1740
+ using Mask = typename UnderlyingIterator::Mask;
1741
+
1742
+ /// Parameters object is precomputed state and is host-constructible
1743
+ class Params {
1744
+ private:
1745
+ friend PredicatedTileIterator;
1746
+
1747
+ /// Parameters object
1748
+ typename UnderlyingIterator::Params params_;
1749
+
1750
+ public:
1751
+
1752
+ /// Default constructor
1753
+ Params() = default;
1754
+
1755
+ /// Construct the Params object given a pitch-linear tensor's layout
1756
+ CUTLASS_HOST_DEVICE
1757
+ Params(Layout const &layout)
1758
+ : params_(layout::PitchLinear(layout.stride(0))) {}
1759
+
1760
+ CUTLASS_HOST_DEVICE
1761
+ Params(typename UnderlyingIterator::Params::Base const &base)
1762
+ : params_(base) {}
1763
+ };
1764
+
1765
+ private:
1766
+ //
1767
+ // Data members
1768
+ //
1769
+
1770
+ /// Underlying pitch-linear tile iterator
1771
+ UnderlyingIterator iterator_;
1772
+
1773
+ public:
1774
+
1775
+ /// Default constructor
1776
+ PredicatedTileIterator() = default;
1777
+
1778
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
1779
+ /// and thread ID
1780
+ CUTLASS_HOST_DEVICE
1781
+ PredicatedTileIterator(
1782
+ /// Precomputed parameters object
1783
+ Params const &params,
1784
+ /// Pointer to start of tensor
1785
+ Pointer pointer,
1786
+ /// Extent of tensor
1787
+ TensorCoord extent,
1788
+ /// ID of each participating thread
1789
+ int thread_id,
1790
+ /// Initial offset of threadblock
1791
+ TensorCoord const &threadblock_offset,
1792
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
1793
+ )
1794
+ : iterator_(params.params_, pointer,
1795
+ layout::PitchLinearCoord(extent.column() * kInterleavedK,
1796
+ extent.row() / kInterleavedK),
1797
+ thread_id,
1798
+ layout::PitchLinearCoord(
1799
+ threadblock_offset.column() * kInterleavedK,
1800
+ threadblock_offset.row() / kInterleavedK)) {}
1801
+
1802
+ /// Construct a PredicatedTileIterator with zero threadblock offset
1803
+ CUTLASS_HOST_DEVICE
1804
+ PredicatedTileIterator(
1805
+ Params const &params, ///< Precomputed parameters object
1806
+ Pointer pointer, ///< Pointer to start of tensor
1807
+ TensorCoord extent, ///< Extent of tensor
1808
+ int thread_id ///< ID of each participating thread
1809
+ )
1810
+ : PredicatedTileIterator(params, pointer, extent, thread_id,
1811
+ make_Coord(0, 0)) {}
1812
+
1813
+ /// Adds a pointer offset in units of Element
1814
+ CUTLASS_HOST_DEVICE
1815
+ void add_pointer_offset(LongIndex pointer_offset) {
1816
+ iterator_.add_pointer_offset(pointer_offset);
1817
+ }
1818
+
1819
+ /// Advances to the next tile in memory.
1820
+ ///
1821
+ /// The first time this method is called, predicates are updated, and the
1822
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1823
+ /// Subsequent calls are lightweight and must only update the internal
1824
+ /// pointer.
1825
+ CUTLASS_HOST_DEVICE
1826
+ PredicatedTileIterator &operator++() {
1827
+ ++iterator_;
1828
+ return *this;
1829
+ }
1830
+
1831
+ /// Advances to the next tile in memory.
1832
+ ///
1833
+ /// The first time this method is called, predicates are updated, and the
1834
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
1835
+ /// Subsequent calls are lightweight and must only update the internal
1836
+ /// pointer.
1837
+ CUTLASS_HOST_DEVICE
1838
+ PredicatedTileIterator operator++(int) {
1839
+ PredicatedTileIterator self(*this);
1840
+ operator++();
1841
+ return self;
1842
+ }
1843
+
1844
+ /// Clears the predicate set efficiently
1845
+ CUTLASS_HOST_DEVICE
1846
+ void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
1847
+
1848
+ /// Clears the predicate set efficiently
1849
+ CUTLASS_HOST_DEVICE
1850
+ void enable_mask() { iterator_.enable_mask(); }
1851
+
1852
+ /// Sets the predicate mask, overriding value stored in predicate iterator
1853
+ CUTLASS_HOST_DEVICE
1854
+ void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1855
+
1856
+ /// Gets the mask
1857
+ CUTLASS_HOST_DEVICE
1858
+ void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1859
+
1860
+ /// Loads a fragment from memory
1861
+ CUTLASS_DEVICE
1862
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1863
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
1864
+ }
1865
+
1866
+ /// Loads a fragment from memory
1867
+ CUTLASS_DEVICE
1868
+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
1869
+
1870
+ /// Store a fragment to memory
1871
+ CUTLASS_DEVICE
1872
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1873
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
1874
+ }
1875
+
1876
+ /// Store a fragment to memory
1877
+ CUTLASS_DEVICE
1878
+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
1879
+ };
1880
+
1881
+ ////////////////////////////////////////////////////////////////////////////////
1882
+
1883
+ } // namespace threadblock
1884
+ } // namespace transform
1885
+ } // namespace cutlass
1886
+
1887
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h ADDED
@@ -0,0 +1,787 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
33
+
34
+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile
35
+ first, with the objective of minimizing predicate mask updates during steady-state operation.
36
+
37
+ A precomputed "Params" object minimizes the amount of state that must be stored in registers,
38
+ and integer addition is used to advance the pointer through memory.
39
+ */
40
+
41
+ #pragma once
42
+
43
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h"
44
+ #include "cutlass/transform/thread/transpose.h"
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace transform {
50
+ namespace threadblock {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// PredicatedTileIterator2dThreadTile
55
+ ///
56
+ /// Satisfies: ForwardTileIteratorConcept |
57
+ /// ReadableContiguousTileIteratorConcept |
58
+ /// WriteableContiguousTileIteratorConcept |
59
+ /// MaskedTileIteratorConcept
60
+ ///
61
+ /// Regular tile iterator using a precomputed control structure to minimize register liveness
62
+ /// and integer arithmetic.
63
+ ///
64
+ /// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
65
+ ///
66
+ /// Base pointer and tensor extents may be specified at the time the iterator is constructed.
67
+ /// Subsequently, they are assumed to be immutable.
68
+ ///
69
+ /// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
70
+ /// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
71
+ ///
72
+ /// Vistitation order is intended to first visit a "residual" tile that may be partially full in
73
+ /// both the advance dimension and the steady-state dimension. This is assumed to be the last
74
+ /// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
75
+ /// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
76
+ /// accesses may be performed without updating internal predicates and are efficient in terms of
77
+ /// live register state and pointer arithmetic instructions.
78
+ ///
79
+ /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
80
+ /// outside any looping structure to minimize integer arithmetic.
81
+ ///
82
+ /// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
83
+ /// the iterator.
84
+ ///
85
+ ///
86
+ /// Example:
87
+ ///
88
+ /// An efficient pipeline structure may be constructed as follows:
89
+ ///
90
+ // template <typename Iterator>
91
+ // __global__ void kernel(
92
+ // typename Iterator::Params params,
93
+ // typename Iterator::Element *ptr,
94
+ // TensorCoord extent) {
95
+ //
96
+ // typename Iterator::Fragment fragment;
97
+ //
98
+ // TensorCoord threadblock_offset(0, 0);
99
+ //
100
+ // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
101
+ //
102
+ //
103
+ // fragment = *iter; // load "residue" tile first
104
+ // ++iter; // advance to first "steady state" tile and update internal masks
105
+ //
106
+ //
107
+ // #pragma unroll
108
+ // for (int i = Remaining - 1; i >= 0; --i) {
109
+ //
110
+ // f(fragment);
111
+ //
112
+ // if (!i) {
113
+ // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
114
+ // }
115
+ //
116
+ // fragment = *iter; // load tile during "steady state" phase
117
+ // ++iter; // advance to next tile - lightweight due to steady-state masks
118
+ // }
119
+ // }
120
+ //
121
+ // void host(TensorView<Element, 2, layout::PitchLinear> view) {
122
+ //
123
+ // using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile;
124
+ //
125
+ // typename Iterator::Params params(view.layout());
126
+ //
127
+ // kernel<Iterator>(params, view.data());
128
+ // }
129
+ ///
130
+ ///
131
+ template <
132
+ typename Shape,
133
+ typename Element,
134
+ typename Layout,
135
+ int AdvanceRank,
136
+ typename ThreadMap,
137
+ bool Transpose = false
138
+ >
139
+ class PredicatedTileIterator2dThreadTile;
140
+
141
+ ////////////////////////////////////////////////////////////////////////////////
142
+
143
+ /// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
144
+ ///
145
+ /// Satisfies: ForwardTileIteratorConcept |
146
+ /// ReadableContiguousTileIteratorConcept |
147
+ /// WriteableContiguousTileIteratorConcept |
148
+ /// MaskedTileIteratorConcept
149
+ ///
150
+ template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, bool Transpose_>
151
+ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_, Transpose_> {
152
+ public:
153
+ static_assert(
154
+ AdvanceRank == 0 || AdvanceRank == 1,
155
+ "Specialization for pitch-linear iterator may along advance along the "
156
+ "contiguous(rank=0) or strided(rank=1) dimension.");
157
+
158
+ using Shape = Shape_;
159
+ using Element = Element_;
160
+ using Layout = layout::PitchLinear;
161
+ static int const kAdvanceRank = AdvanceRank;
162
+ using ThreadMap = ThreadMap_;
163
+
164
+ using Index = typename Layout::Index;
165
+ using LongIndex = typename Layout::LongIndex;
166
+
167
+ using TensorRef = TensorRef<Element, Layout>;
168
+ using TensorView = TensorView<Element, Layout>;
169
+ using TensorCoord = typename Layout::TensorCoord;
170
+
171
+ using Pointer = Element *;
172
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
173
+
174
+ /// Type used for internal memory accesses
175
+ /// extra set of parenthesis is needed for VS compiler
176
+ struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value /
177
+ 8)) AccessType {
178
+
179
+ Array<Element, ThreadMap::kElementsPerAccess> storage;
180
+
181
+ static int const kElements = ThreadMap::kElementsPerAccess;
182
+ };
183
+
184
+ /// Optionally this fragment can be 4x4 transposed
185
+ using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>;
186
+ static bool const transpose = Transpose_;
187
+
188
+ /// Underlying iterator to compute the addresses
189
+ using TileAccessIterator =
190
+ PredicatedTileAccessIterator2dThreadTile<Shape, Element, Layout, kAdvanceRank,
191
+ ThreadMap, AccessType>;
192
+
193
+ /// Fragment object to be loaded or stored
194
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
195
+ ThreadMap::ThreadAccessShape::kCount>;
196
+
197
+ /// Predicate vector stores mask to guard accesses
198
+ using Mask = typename TileAccessIterator::Mask;
199
+
200
+ /// Parameters object is precomputed state and is host-constructible
201
+ class Params {
202
+ public:
203
+ using Base = typename TileAccessIterator::Params::Base;
204
+
205
+ friend PredicatedTileIterator2dThreadTile;
206
+
207
+ private:
208
+ /// Parameters object
209
+ typename TileAccessIterator::Params params_;
210
+
211
+ public:
212
+ /// Construct the Params object given a pitch-linear tensor's layout
213
+ CUTLASS_HOST_DEVICE
214
+ Params(Layout const &layout) : params_(layout) { }
215
+
216
+ CUTLASS_HOST_DEVICE
217
+ Params() { }
218
+
219
+ CUTLASS_HOST_DEVICE
220
+ Params(Base const &base)
221
+ : params_(base) {}
222
+ };
223
+
224
+ private:
225
+ /// Internal pointer type permits fast address arithmetic
226
+ using BytePointer = char *;
227
+
228
+ private:
229
+ //
230
+ // Data members
231
+ //
232
+
233
+ /// Data member to the tile access iterator
234
+ TileAccessIterator address_iterator_;
235
+
236
+ public:
237
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
238
+ /// and thread ID
239
+ CUTLASS_HOST_DEVICE
240
+ PredicatedTileIterator2dThreadTile(
241
+ /// Precomputed parameters object
242
+ Params const &params,
243
+ /// Pointer to start of tensor
244
+ Pointer pointer,
245
+ /// Extent of tensor
246
+ TensorCoord extent,
247
+ /// ID of each participating thread
248
+ int thread_id,
249
+ /// Initial offset of threadblock
250
+ TensorCoord const &threadblock_offset,
251
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
252
+ )
253
+ : address_iterator_(params.params_, pointer, extent, thread_id,
254
+ threadblock_offset) {}
255
+
256
+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
257
+ CUTLASS_HOST_DEVICE
258
+ PredicatedTileIterator2dThreadTile(
259
+ Params const &params, ///< Precomputed parameters object
260
+ Pointer pointer, ///< Pointer to start of tensor
261
+ TensorCoord extent, ///< Extent of tensor
262
+ int thread_id ///< ID of each participating thread
263
+ )
264
+ : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id,
265
+ make_Coord(0, 0)) {}
266
+
267
+ /// Adds a pointer offset in units of Element
268
+ CUTLASS_HOST_DEVICE
269
+ void add_pointer_offset(LongIndex pointer_offset) {
270
+ address_iterator_.add_pointer_offset(pointer_offset);
271
+ }
272
+
273
+ /// Advances to the next tile in memory.
274
+ ///
275
+ /// The first time this method is called, predicates are updated, and the
276
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
277
+ /// Subsequent calls are lightweight and must only update the internal
278
+ /// pointer.
279
+ CUTLASS_HOST_DEVICE
280
+ PredicatedTileIterator2dThreadTile &operator++() {
281
+ if (kAdvanceRank)
282
+ address_iterator_.add_tile_offset({0, 1});
283
+ else
284
+ address_iterator_.add_tile_offset({1, 0});
285
+
286
+ return *this;
287
+ }
288
+
289
+ /// Advances to the next tile in memory.
290
+ ///
291
+ /// The first time this method is called, predicates are updated, and the
292
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
293
+ /// Subsequent calls are lightweight and must only update the internal
294
+ /// pointer.
295
+ CUTLASS_HOST_DEVICE
296
+ PredicatedTileIterator2dThreadTile operator++(int) {
297
+ PredicatedTileIterator2dThreadTile self(*this);
298
+ operator++();
299
+ return self;
300
+ }
301
+
302
+ /// Clears the predicate set efficiently
303
+ CUTLASS_HOST_DEVICE
304
+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
305
+
306
+ /// Clears the predicate set efficiently
307
+ CUTLASS_HOST_DEVICE
308
+ void enable_mask() { address_iterator_.enable_mask(); }
309
+
310
+ /// Sets the predicate mask, overriding value stored in predicate iterator
311
+ CUTLASS_HOST_DEVICE
312
+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
313
+
314
+ /// Gets the mask
315
+ CUTLASS_HOST_DEVICE
316
+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
317
+
318
+ /// Loads a fragment from memory
319
+ CUTLASS_DEVICE
320
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
321
+
322
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
323
+
324
+ CUTLASS_PRAGMA_UNROLL
325
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
326
+ CUTLASS_PRAGMA_UNROLL
327
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
328
+ CUTLASS_PRAGMA_UNROLL
329
+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
330
+
331
+ int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
332
+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
333
+
334
+ address_iterator_.set_iteration_index(access_idx);
335
+ if (address_iterator_.valid()) {
336
+
337
+ frag_ptr[access_idx] =
338
+ *(address_iterator_.get() + pointer_offset);
339
+ }
340
+
341
+ ++address_iterator_;
342
+ }
343
+ }
344
+ }
345
+
346
+ if (transpose) {
347
+ Transform t;
348
+ t.transform(frag, frag);
349
+ }
350
+ }
351
+
352
+ /// Loads a fragment from memory
353
+ CUTLASS_DEVICE
354
+ void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
355
+
356
+ /// Store a fragment to memory
357
+ CUTLASS_DEVICE
358
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
359
+
360
+ AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
361
+
362
+ CUTLASS_PRAGMA_UNROLL
363
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
364
+ CUTLASS_PRAGMA_UNROLL
365
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
366
+ CUTLASS_PRAGMA_UNROLL
367
+ for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
368
+
369
+ int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
370
+ s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
371
+
372
+ address_iterator_.set_iteration_index(access_idx);
373
+ if (address_iterator_.valid()) {
374
+ *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
375
+ }
376
+ ++address_iterator_;
377
+ }
378
+ }
379
+ }
380
+ }
381
+
382
+ /// Store a fragment to memory
383
+ CUTLASS_DEVICE
384
+ void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
385
+ };
386
+
387
+ ////////////////////////////////////////////////////////////////////////////////
388
+
389
+ /// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
390
+ ///
391
+ /// Satisfies: ForwardTileIteratorConcept |
392
+ /// ReadableContiguousTileIteratorConcept |
393
+ /// WriteableContiguousTileIteratorConcept |
394
+ /// MaskedTileIteratorConcept
395
+ ///
396
+ template <
397
+ typename Shape_,
398
+ typename Element_,
399
+ int AdvanceRank,
400
+ typename ThreadMap_,
401
+ bool Transpose_
402
+ >
403
+ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, Transpose_> {
404
+ public:
405
+
406
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
407
+ "Specialization for pitch-linear iterator may along advance along the "
408
+ "contiguous(rank=0) or strided(rank=1) dimension.");
409
+
410
+ using Shape = Shape_;
411
+ using Element = Element_;
412
+ using Layout = layout::ColumnMajor;
413
+ static int const kAdvanceRank = AdvanceRank;
414
+ using ThreadMap = ThreadMap_;
415
+ static bool const Transpose = Transpose_;
416
+
417
+ using Index = typename Layout::Index;
418
+ using LongIndex = typename Layout::LongIndex;
419
+
420
+ using TensorRef = TensorRef<Element, Layout>;
421
+ using TensorView = TensorView<Element, Layout>;
422
+ using TensorCoord = typename Layout::TensorCoord;
423
+
424
+ using Pointer = Element *;
425
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
426
+
427
+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
428
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
429
+ Element,
430
+ layout::PitchLinear,
431
+ (kAdvanceRank == 0 ? 0 : 1),
432
+ ThreadMap,
433
+ Transpose
434
+ >;
435
+
436
+ using AccessType = typename UnderlyingIterator::AccessType;
437
+
438
+ /// Fragment object to be loaded or stored
439
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
440
+
441
+ /// Predicate vector stores mask to guard accesses
442
+ using Mask = typename UnderlyingIterator::Mask;
443
+
444
+ /// Parameters object is precomputed state and is host-constructible
445
+ class Params {
446
+ private:
447
+
448
+ friend PredicatedTileIterator2dThreadTile;
449
+
450
+ /// Parameters object
451
+ typename UnderlyingIterator::Params params_;
452
+
453
+ public:
454
+
455
+ CUTLASS_HOST_DEVICE
456
+ Params() { }
457
+
458
+ /// Construct the Params object given a pitch-linear tensor's layout
459
+ CUTLASS_HOST_DEVICE
460
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}
461
+
462
+ CUTLASS_HOST_DEVICE
463
+ Params(typename UnderlyingIterator::Params::Base const &base)
464
+ : params_(base) {}
465
+ };
466
+
467
+
468
+ private:
469
+
470
+ //
471
+ // Data members
472
+ //
473
+
474
+ /// Underlying pitch-linear tile iterator
475
+ UnderlyingIterator iterator_;
476
+
477
+ public:
478
+
479
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
480
+ CUTLASS_HOST_DEVICE
481
+ PredicatedTileIterator2dThreadTile(
482
+ Params const &params, ///< Precomputed parameters object
483
+ Pointer pointer, ///< Pointer to start of tensor
484
+ TensorCoord extent, ///< Extent of tensor
485
+ int thread_id, ///< ID of each participating thread
486
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
487
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
488
+ ):
489
+ iterator_(
490
+ params.params_,
491
+ pointer,
492
+ layout::PitchLinearCoord(extent.row(), extent.column()),
493
+ thread_id,
494
+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
495
+ ) { }
496
+
497
+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
498
+ CUTLASS_HOST_DEVICE
499
+ PredicatedTileIterator2dThreadTile(
500
+ Params const &params, ///< Precomputed parameters object
501
+ Pointer pointer, ///< Pointer to start of tensor
502
+ TensorCoord extent, ///< Extent of tensor
503
+ int thread_id ///< ID of each participating thread
504
+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
505
+
506
+ /// Adds a pointer offset in units of Element
507
+ CUTLASS_HOST_DEVICE
508
+ void add_pointer_offset(LongIndex pointer_offset) {
509
+ iterator_.add_pointer_offset(pointer_offset);
510
+ }
511
+
512
+ /// Advances to the next tile in memory.
513
+ ///
514
+ /// The first time this method is called, predicates are updated, and the iterator's
515
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
516
+ /// are lightweight and must only update the internal pointer.
517
+ CUTLASS_HOST_DEVICE
518
+ PredicatedTileIterator2dThreadTile &operator++() {
519
+ ++iterator_;
520
+ return *this;
521
+ }
522
+
523
+ /// Advances to the next tile in memory.
524
+ ///
525
+ /// The first time this method is called, predicates are updated, and the iterator's
526
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
527
+ /// are lightweight and must only update the internal pointer.
528
+ CUTLASS_HOST_DEVICE
529
+ PredicatedTileIterator2dThreadTile operator++(int) {
530
+ PredicatedTileIterator2dThreadTile self(*this);
531
+ operator++();
532
+ return self;
533
+ }
534
+
535
+ /// Clears the predicate set efficiently
536
+ CUTLASS_HOST_DEVICE
537
+ void clear_mask(bool enable = true) {
538
+ iterator_.clear_mask(enable);
539
+ }
540
+
541
+ /// Clears the predicate set efficiently
542
+ CUTLASS_HOST_DEVICE
543
+ void enable_mask() {
544
+ iterator_.enable_mask();
545
+ }
546
+
547
+ /// Sets the predicate mask, overriding value stored in predicate iterator
548
+ CUTLASS_HOST_DEVICE
549
+ void set_mask(Mask const &mask) {
550
+ iterator_.set_mask(mask);
551
+ }
552
+
553
+ /// Gets the mask
554
+ CUTLASS_HOST_DEVICE
555
+ void get_mask(Mask &mask) {
556
+ iterator_.get_mask(mask);
557
+ }
558
+
559
+ /// Loads a fragment from memory
560
+ CUTLASS_DEVICE
561
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
562
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
563
+ }
564
+
565
+ /// Loads a fragment from memory
566
+ CUTLASS_DEVICE
567
+ void load(Fragment &frag) {
568
+ load_with_pointer_offset(frag, 0);
569
+ }
570
+
571
+ /// Store a fragment to memory
572
+ CUTLASS_DEVICE
573
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
574
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
575
+ }
576
+
577
+ /// Store a fragment to memory
578
+ CUTLASS_DEVICE
579
+ void store(Fragment const &frag) {
580
+ store_with_pointer_offset(frag, 0);
581
+ }
582
+ };
583
+
584
+ ////////////////////////////////////////////////////////////////////////////////
585
+
586
+ /// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
587
+ ///
588
+ /// Satisfies: ForwardTileIteratorConcept |
589
+ /// ReadableContiguousTileIteratorConcept |
590
+ /// WriteableContiguousTileIteratorConcept |
591
+ /// MaskedTileIteratorConcept
592
+ ///
593
+ template <
594
+ typename Shape_,
595
+ typename Element_,
596
+ int AdvanceRank,
597
+ typename ThreadMap_,
598
+ bool Transpose_
599
+ >
600
+ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, Transpose_> {
601
+ public:
602
+
603
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
604
+ "Specialization for pitch-linear iterator may along advance along the "
605
+ "contiguous(rank=0) or strided(rank=1) dimension.");
606
+
607
+ using Shape = Shape_;
608
+ using Element = Element_;
609
+ using Layout = layout::RowMajor;
610
+ static int const kAdvanceRank = AdvanceRank;
611
+ using ThreadMap = ThreadMap_;
612
+ static bool const Transpose = Transpose_;
613
+
614
+ using Index = typename Layout::Index;
615
+ using LongIndex = typename Layout::LongIndex;
616
+
617
+ using TensorRef = TensorRef<Element, Layout>;
618
+ using TensorView = TensorView<Element, Layout>;
619
+ using TensorCoord = typename Layout::TensorCoord;
620
+
621
+ using Pointer = Element *;
622
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
623
+
624
+ using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
625
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
626
+ Element,
627
+ layout::PitchLinear,
628
+ (kAdvanceRank == 0 ? 1 : 0),
629
+ ThreadMap,
630
+ Transpose
631
+ >;
632
+
633
+ using AccessType = typename UnderlyingIterator::AccessType;
634
+
635
+ /// Fragment object to be loaded or stored
636
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
637
+
638
+ /// Predicate vector stores mask to guard accesses
639
+ using Mask = typename UnderlyingIterator::Mask;
640
+
641
+ /// Parameters object is precomputed state and is host-constructible
642
+ class Params {
643
+ private:
644
+
645
+ friend PredicatedTileIterator2dThreadTile;
646
+
647
+ /// Parameters object
648
+ typename UnderlyingIterator::Params params_;
649
+
650
+ public:
651
+
652
+ CUTLASS_HOST_DEVICE
653
+ Params() { }
654
+
655
+ /// Construct the Params object given a pitch-linear tensor's layout
656
+ CUTLASS_HOST_DEVICE
657
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { }
658
+
659
+ CUTLASS_HOST_DEVICE
660
+ Params(typename UnderlyingIterator::Params::Base const &base)
661
+ : params_(base) {}
662
+ };
663
+
664
+
665
+ private:
666
+
667
+ //
668
+ // Data members
669
+ //
670
+
671
+ /// Underlying pitch-linear tile iterator
672
+ UnderlyingIterator iterator_;
673
+
674
+ public:
675
+
676
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
677
+ CUTLASS_HOST_DEVICE
678
+ PredicatedTileIterator2dThreadTile(
679
+ Params const &params, ///< Precomputed parameters object
680
+ Pointer pointer, ///< Pointer to start of tensor
681
+ TensorCoord extent, ///< Extent of tensor
682
+ int thread_id, ///< ID of each participating thread
683
+ TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
684
+ int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
685
+ ):
686
+ iterator_(
687
+ params.params_,
688
+ pointer,
689
+ layout::PitchLinearCoord(extent.column(), extent.row()),
690
+ thread_id,
691
+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
692
+ ) { }
693
+
694
+ /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
695
+ CUTLASS_HOST_DEVICE
696
+ PredicatedTileIterator2dThreadTile(
697
+ Params const &params, ///< Precomputed parameters object
698
+ Pointer pointer, ///< Pointer to start of tensor
699
+ TensorCoord extent, ///< Extent of tensor
700
+ int thread_id ///< ID of each participating thread
701
+ ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
702
+
703
+ /// Adds a pointer offset in units of Element
704
+ CUTLASS_HOST_DEVICE
705
+ void add_pointer_offset(LongIndex pointer_offset) {
706
+ iterator_.add_pointer_offset(pointer_offset);
707
+ }
708
+
709
+ /// Advances to the next tile in memory.
710
+ ///
711
+ /// The first time this method is called, predicates are updated, and the iterator's
712
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
713
+ /// are lightweight and must only update the internal pointer.
714
+ CUTLASS_HOST_DEVICE
715
+ PredicatedTileIterator2dThreadTile &operator++() {
716
+ ++iterator_;
717
+ return *this;
718
+ }
719
+
720
+ /// Advances to the next tile in memory.
721
+ ///
722
+ /// The first time this method is called, predicates are updated, and the iterator's
723
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
724
+ /// are lightweight and must only update the internal pointer.
725
+ CUTLASS_HOST_DEVICE
726
+ PredicatedTileIterator2dThreadTile operator++(int) {
727
+ PredicatedTileIterator2dThreadTile self(*this);
728
+ operator++();
729
+ return self;
730
+ }
731
+
732
+ /// Clears the predicate set efficiently
733
+ CUTLASS_HOST_DEVICE
734
+ void clear_mask(bool enable = true) {
735
+ iterator_.clear_mask(enable);
736
+ }
737
+
738
+ /// Clears the predicate set efficiently
739
+ CUTLASS_HOST_DEVICE
740
+ void enable_mask() {
741
+ iterator_.enable_mask();
742
+ }
743
+
744
+ /// Sets the predicate mask, overriding value stored in predicate iterator
745
+ CUTLASS_HOST_DEVICE
746
+ void set_mask(Mask const &mask) {
747
+ iterator_.set_mask(mask);
748
+ }
749
+
750
+ /// Gets the mask
751
+ CUTLASS_HOST_DEVICE
752
+ void get_mask(Mask &mask) {
753
+ iterator_.get_mask(mask);
754
+ }
755
+
756
+ /// Loads a fragment from memory
757
+ CUTLASS_DEVICE
758
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
759
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
760
+ }
761
+
762
+ /// Loads a fragment from memory
763
+ CUTLASS_DEVICE
764
+ void load(Fragment &frag) {
765
+ load_with_pointer_offset(frag, 0);
766
+ }
767
+
768
+ /// Store a fragment to memory
769
+ CUTLASS_DEVICE
770
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
771
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
772
+ }
773
+
774
+ /// Store a fragment to memory
775
+ CUTLASS_DEVICE
776
+ void store(Fragment const &frag) {
777
+ store_with_pointer_offset(frag, 0);
778
+ }
779
+ };
780
+
781
+ ////////////////////////////////////////////////////////////////////////////////
782
+
783
+ } // namespace threadblock
784
+ } // namespace transform
785
+ } // namespace cutlass
786
+
787
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
33
+
34
+ This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile
35
+ first, with the objective of minimizing predicate mask updates during steady-state operation.
36
+
37
+ A precomputed "Params" object minimizes the amount of state that must be stored in registers,
38
+ and integer addition is used to advance the pointer through memory.
39
+ */
40
+
41
+ #pragma once
42
+
43
+ #include "cutlass/arch/memory.h"
44
+ #include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h"
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ namespace cutlass {
49
+ namespace transform {
50
+ namespace threadblock {
51
+
52
+ ////////////////////////////////////////////////////////////////////////////////
53
+
54
+ /// PredicatedTileIteratorTriangularMatrix
55
+ ///
56
+ /// Satisfies: ForwardTileIteratorConcept |
57
+ /// ReadableContiguousTileIteratorConcept |
58
+ /// WriteableContiguousTileIteratorConcept |
59
+ /// MaskedTileIteratorConcept
60
+ ///
61
+ /// Regular tile iterator using a precomputed control structure to minimize register liveness
62
+ /// and integer arithmetic.
63
+ ///
64
+ /// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
65
+ ///
66
+ /// Base pointer and tensor extents may be specified at the time the iterator is constructed.
67
+ /// Subsequently, they are assumed to be immutable.
68
+ ///
69
+ /// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
70
+ /// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
71
+ ///
72
+ /// Vistitation order is intended to first visit a "residual" tile that may be partially full in
73
+ /// both the advance dimension and the steady-state dimension. This is assumed to be the last
74
+ /// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
75
+ /// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
76
+ /// accesses may be performed without updating internal predicates and are efficient in terms of
77
+ /// live register state and pointer arithmetic instructions.
78
+ ///
79
+ /// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
80
+ /// outside any looping structure to minimize integer arithmetic.
81
+ ///
82
+ /// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
83
+ /// the iterator.
84
+ ///
85
+ ///
86
+ /// Example:
87
+ ///
88
+ /// An efficient pipeline structure may be constructed as follows:
89
+ ///
90
+ // template <typename Iterator>
91
+ // __global__ void kernel(
92
+ // typename Iterator::Params params,
93
+ // typename Iterator::Element *ptr,
94
+ // TensorCoord extent) {
95
+ //
96
+ // typename Iterator::Fragment fragment;
97
+ //
98
+ // TensorCoord threadblock_offset(0, 0);
99
+ //
100
+ // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
101
+ //
102
+ //
103
+ // fragment = *iter; // load "residue" tile first
104
+ // ++iter; // advance to first "steady state" tile and update internal masks
105
+ //
106
+ //
107
+ // #pragma unroll
108
+ // for (int i = Remaining - 1; i >= 0; --i) {
109
+ //
110
+ // f(fragment);
111
+ //
112
+ // if (!i) {
113
+ // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
114
+ // }
115
+ //
116
+ // fragment = *iter; // load tile during "steady state" phase
117
+ // ++iter; // advance to next tile - lightweight due to steady-state masks
118
+ // }
119
+ // }
120
+ //
121
+ // void host(TensorView<Element, 2, layout::PitchLinear> view) {
122
+ //
123
+ // using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix;
124
+ //
125
+ // typename Iterator::Params params(view.layout());
126
+ //
127
+ // kernel<Iterator>(params, view.data());
128
+ // }
129
+ ///
130
+ ///
131
+ template <
132
+ typename Shape,
133
+ typename Element,
134
+ typename Layout,
135
+ int AdvanceRank,
136
+ typename ThreadMap,
137
+ SideMode kSideMode,
138
+ FillMode kFillMode,
139
+ DiagType kDiagType,
140
+ int AccessSize = ThreadMap::kElementsPerAccess
141
+ >
142
+ class PredicatedTileIteratorTriangularMatrix;
143
+
144
+ ////////////////////////////////////////////////////////////////////////////////
145
+
146
+ /// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data.
147
+ ///
148
+ /// Satisfies: ForwardTileIteratorConcept |
149
+ /// ReadableContiguousTileIteratorConcept |
150
+ /// WriteableContiguousTileIteratorConcept |
151
+ /// MaskedTileIteratorConcept
152
+ ///
153
+ template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
154
+ SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
155
+ int AccessSize>
156
+ class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_,
157
+ kSideMode, kFillMode, kDiagType,
158
+ AccessSize> {
159
+ public:
160
+ static_assert(
161
+ AdvanceRank == 0 || AdvanceRank == 1,
162
+ "Specialization for pitch-linear iterator may along advance along the "
163
+ "contiguous(rank=0) or strided(rank=1) dimension.");
164
+
165
+ using Shape = Shape_;
166
+ using Element = Element_;
167
+ using Layout = layout::PitchLinear;
168
+ static int const kAdvanceRank = AdvanceRank;
169
+ using ThreadMap = ThreadMap_;
170
+
171
+ using Index = typename Layout::Index;
172
+ using LongIndex = typename Layout::LongIndex;
173
+
174
+ using TensorRef = TensorRef<Element, Layout>;
175
+ using TensorView = TensorView<Element, Layout>;
176
+ using TensorCoord = typename Layout::TensorCoord;
177
+
178
+ using Pointer = Element *;
179
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
180
+
181
+ /// Type used for internal memory accesses
182
+ using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
183
+
184
+ /// Underlying iterator to compute the addresses
185
+ using TileAccessIterator =
186
+ PredicatedTileAccessIteratorTriangularMatrix<Shape, Element, Layout, kAdvanceRank,
187
+ ThreadMap, kSideMode, kFillMode, kDiagType, AccessType>;
188
+
189
+ static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
190
+
191
+ /// Fragment object to be loaded or stored
192
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
193
+ ThreadMap::kElementsPerAccess>;
194
+
195
+ /// Predicate vector stores mask to guard accesses
196
+ using Mask = typename TileAccessIterator::Mask;
197
+
198
+ /// Parameters object is precomputed state and is host-constructible
199
+ class Params {
200
+ public:
201
+ friend PredicatedTileIteratorTriangularMatrix;
202
+
203
+ private:
204
+ /// Parameters object
205
+ typename TileAccessIterator::Params params_;
206
+
207
+ public:
208
+ /// Construct the Params object given a pitch-linear tensor's layout
209
+ CUTLASS_HOST_DEVICE
210
+ Params(Layout const &layout) : params_(layout) { }
211
+
212
+ CUTLASS_HOST_DEVICE
213
+ Params() { }
214
+ };
215
+
216
+ private:
217
+ /// Internal pointer type permits fast address arithmetic
218
+ using BytePointer = char *;
219
+
220
+ private:
221
+ //
222
+ // Data members
223
+ //
224
+
225
+ /// Data member to the tile access iterator
226
+ TileAccessIterator address_iterator_;
227
+
228
+ public:
229
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
230
+ /// and thread ID
231
+ CUTLASS_HOST_DEVICE
232
+ PredicatedTileIteratorTriangularMatrix(
233
+ /// Precomputed parameters object
234
+ Params const &params,
235
+ /// Pointer to start of tensor
236
+ Pointer pointer,
237
+ /// Extent of tensor
238
+ TensorCoord extent,
239
+ /// ID of each participating thread
240
+ int thread_id,
241
+ /// Initial offset of threadblock
242
+ TensorCoord const &threadblock_offset)
243
+ : address_iterator_(params.params_, pointer, extent, thread_id,
244
+ threadblock_offset) {}
245
+
246
+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
247
+ CUTLASS_HOST_DEVICE
248
+ PredicatedTileIteratorTriangularMatrix(
249
+ Params const &params, ///< Precomputed parameters object
250
+ Pointer pointer, ///< Pointer to start of tensor
251
+ TensorCoord extent, ///< Extent of tensor
252
+ int thread_id ///< ID of each participating thread
253
+ )
254
+ : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id,
255
+ make_Coord(0, 0)) {}
256
+
257
+ /// Adds a pointer offset in units of Element
258
+ CUTLASS_HOST_DEVICE
259
+ void add_pointer_offset(LongIndex pointer_offset) {
260
+ address_iterator_.add_pointer_offset(pointer_offset);
261
+ }
262
+
263
+ /// Advances to the next tile in memory.
264
+ ///
265
+ /// The first time this method is called, predicates are updated, and the
266
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
267
+ /// Subsequent calls are lightweight and must only update the internal
268
+ /// pointer.
269
+ CUTLASS_HOST_DEVICE
270
+ PredicatedTileIteratorTriangularMatrix &operator++() {
271
+ if (kAdvanceRank)
272
+ address_iterator_.add_tile_offset({0, 1});
273
+ else
274
+ address_iterator_.add_tile_offset({1, 0});
275
+
276
+ return *this;
277
+ }
278
+
279
+ /// Advances to the next tile in memory.
280
+ ///
281
+ /// The first time this method is called, predicates are updated, and the
282
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
283
+ /// Subsequent calls are lightweight and must only update the internal
284
+ /// pointer.
285
+ CUTLASS_HOST_DEVICE
286
+ PredicatedTileIteratorTriangularMatrix operator++(int) {
287
+ PredicatedTileIteratorTriangularMatrix self(*this);
288
+ operator++();
289
+ return self;
290
+ }
291
+
292
+ /// Clears the predicate set efficiently
293
+ CUTLASS_HOST_DEVICE
294
+ void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
295
+
296
+ /// Clears the predicate set efficiently
297
+ CUTLASS_HOST_DEVICE
298
+ void enable_mask() { address_iterator_.enable_mask(); }
299
+
300
+ /// Sets the predicate mask, overriding value stored in predicate iterator
301
+ CUTLASS_HOST_DEVICE
302
+ void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
303
+
304
+ /// Gets the mask
305
+ CUTLASS_HOST_DEVICE
306
+ void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
307
+
308
+ CUTLASS_DEVICE
309
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
310
+ load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
311
+ }
312
+
313
+ CUTLASS_DEVICE
314
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
315
+
316
+ AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
317
+
318
+ CUTLASS_PRAGMA_UNROLL
319
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
320
+ CUTLASS_PRAGMA_UNROLL
321
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
322
+
323
+ CUTLASS_PRAGMA_UNROLL
324
+ for (int v = 0; v < kAccessesPerVector; ++v) {
325
+
326
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
327
+
328
+ address_iterator_.set_iteration_index(idx);
329
+ char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
330
+
331
+ AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
332
+
333
+ cutlass::arch::global_load<AccessType,
334
+ sizeof(AccessType)
335
+ >(
336
+ frag_ptr[idx], access_ptr, address_iterator_.valid());
337
+
338
+ ++address_iterator_;
339
+ }
340
+ }
341
+ }
342
+ }
343
+
344
+ /// Loads a fragment from memory
345
+ CUTLASS_DEVICE
346
+ void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
347
+
348
+ /// Store a fragment to memory
349
+ CUTLASS_DEVICE
350
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
351
+ store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
352
+ }
353
+
354
+ /// Store a fragment to memory
355
+ CUTLASS_DEVICE
356
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
357
+ address_iterator_.set_iteration_index(0);
358
+ AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
359
+
360
+ CUTLASS_PRAGMA_UNROLL
361
+ for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
362
+ CUTLASS_PRAGMA_UNROLL
363
+ for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
364
+ CUTLASS_PRAGMA_UNROLL
365
+ for (int v = 0; v < kAccessesPerVector; ++v) {
366
+
367
+ int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
368
+
369
+ char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
370
+ AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
371
+
372
+ if (address_iterator_.valid()) {
373
+ *access_ptr = frag_ptr[idx];
374
+ }
375
+ ++address_iterator_;
376
+ }
377
+ }
378
+ }
379
+ }
380
+
381
+ /// Store a fragment to memory
382
+ CUTLASS_DEVICE
383
+ void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
384
+ };
385
+
386
+ ////////////////////////////////////////////////////////////////////////////////
387
+
388
+ /// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data.
389
+ ///
390
+ /// Satisfies: ForwardTileIteratorConcept |
391
+ /// ReadableContiguousTileIteratorConcept |
392
+ /// WriteableContiguousTileIteratorConcept |
393
+ /// MaskedTileIteratorConcept
394
+ ///
395
+ template <
396
+ typename Shape_,
397
+ typename Element_,
398
+ int AdvanceRank,
399
+ typename ThreadMap_,
400
+ SideMode kSideMode,
401
+ FillMode kFillMode,
402
+ DiagType kDiagType,
403
+ int AccessSize
404
+ >
405
+ class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_,
406
+ kSideMode, kFillMode, kDiagType,
407
+ AccessSize> {
408
+ public:
409
+
410
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
411
+ "Specialization for pitch-linear iterator may along advance along the "
412
+ "contiguous(rank=0) or strided(rank=1) dimension.");
413
+
414
+ using Shape = Shape_;
415
+ using Element = Element_;
416
+ using Layout = layout::ColumnMajor;
417
+ static int const kAdvanceRank = AdvanceRank;
418
+ using ThreadMap = ThreadMap_;
419
+
420
+ using Index = typename Layout::Index;
421
+ using LongIndex = typename Layout::LongIndex;
422
+
423
+ using TensorRef = TensorRef<Element, Layout>;
424
+ using TensorView = TensorView<Element, Layout>;
425
+ using TensorCoord = typename Layout::TensorCoord;
426
+
427
+ using Pointer = Element *;
428
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
429
+
430
+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix<
431
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
432
+ Element,
433
+ layout::PitchLinear,
434
+ (kAdvanceRank == 0 ? 0 : 1),
435
+ ThreadMap,
436
+ kSideMode,
437
+ kFillMode,
438
+ kDiagType,
439
+ AccessSize
440
+ >;
441
+
442
+ using AccessType = typename UnderlyingIterator::AccessType;
443
+
444
+ /// Fragment object to be loaded or stored
445
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
446
+
447
+ /// Predicate vector stores mask to guard accesses
448
+ using Mask = typename UnderlyingIterator::Mask;
449
+
450
+ /// Parameters object is precomputed state and is host-constructible
451
+ class Params {
452
+ private:
453
+
454
+ friend PredicatedTileIteratorTriangularMatrix;
455
+
456
+ /// Parameters object
457
+ typename UnderlyingIterator::Params params_;
458
+
459
+ public:
460
+
461
+ CUTLASS_HOST_DEVICE
462
+ Params() { }
463
+
464
+ /// Construct the Params object given a pitch-linear tensor's layout
465
+ CUTLASS_HOST_DEVICE
466
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
467
+
468
+ }
469
+ };
470
+
471
+
472
+ private:
473
+
474
+ //
475
+ // Data members
476
+ //
477
+
478
+ /// Underlying pitch-linear tile iterator
479
+ UnderlyingIterator iterator_;
480
+
481
+ public:
482
+
483
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
484
+ CUTLASS_HOST_DEVICE
485
+ PredicatedTileIteratorTriangularMatrix(
486
+ Params const &params, ///< Precomputed parameters object
487
+ Pointer pointer, ///< Pointer to start of tensor
488
+ TensorCoord extent, ///< Extent of tensor
489
+ int thread_id, ///< ID of each participating thread
490
+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock
491
+ ):
492
+ iterator_(
493
+ params.params_,
494
+ pointer,
495
+ layout::PitchLinearCoord(extent.row(), extent.column()),
496
+ thread_id,
497
+ layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
498
+ ) { }
499
+
500
+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
501
+ CUTLASS_HOST_DEVICE
502
+ PredicatedTileIteratorTriangularMatrix(
503
+ Params const &params, ///< Precomputed parameters object
504
+ Pointer pointer, ///< Pointer to start of tensor
505
+ TensorCoord extent, ///< Extent of tensor
506
+ int thread_id ///< ID of each participating thread
507
+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
508
+
509
+ /// Adds a pointer offset in units of Element
510
+ CUTLASS_HOST_DEVICE
511
+ void add_pointer_offset(LongIndex pointer_offset) {
512
+ iterator_.add_pointer_offset(pointer_offset);
513
+ }
514
+
515
+ /// Advances to the next tile in memory.
516
+ ///
517
+ /// The first time this method is called, predicates are updated, and the iterator's
518
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
519
+ /// are lightweight and must only update the internal pointer.
520
+ CUTLASS_HOST_DEVICE
521
+ PredicatedTileIteratorTriangularMatrix &operator++() {
522
+ ++iterator_;
523
+ return *this;
524
+ }
525
+
526
+ /// Advances to the next tile in memory.
527
+ ///
528
+ /// The first time this method is called, predicates are updated, and the iterator's
529
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
530
+ /// are lightweight and must only update the internal pointer.
531
+ CUTLASS_HOST_DEVICE
532
+ PredicatedTileIteratorTriangularMatrix operator++(int) {
533
+ PredicatedTileIteratorTriangularMatrix self(*this);
534
+ operator++();
535
+ return self;
536
+ }
537
+
538
+ /// Clears the predicate set efficiently
539
+ CUTLASS_HOST_DEVICE
540
+ void clear_mask(bool enable = true) {
541
+ iterator_.clear_mask(enable);
542
+ }
543
+
544
+ /// Clears the predicate set efficiently
545
+ CUTLASS_HOST_DEVICE
546
+ void enable_mask() {
547
+ iterator_.enable_mask();
548
+ }
549
+
550
+ /// Sets the predicate mask, overriding value stored in predicate iterator
551
+ CUTLASS_HOST_DEVICE
552
+ void set_mask(Mask const &mask) {
553
+ iterator_.set_mask(mask);
554
+ }
555
+
556
+ /// Gets the mask
557
+ CUTLASS_HOST_DEVICE
558
+ void get_mask(Mask &mask) {
559
+ iterator_.get_mask(mask);
560
+ }
561
+
562
+ /// Loads a fragment from memory
563
+ CUTLASS_DEVICE
564
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
565
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
566
+ }
567
+
568
+ /// Loads a fragment from memory
569
+ CUTLASS_DEVICE
570
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
571
+ iterator_.load_with_byte_offset(frag, byte_offset);
572
+ }
573
+
574
+ /// Loads a fragment from memory
575
+ CUTLASS_DEVICE
576
+ void load(Fragment &frag) {
577
+ load_with_pointer_offset(frag, 0);
578
+ }
579
+
580
+ /// Store a fragment to memory
581
+ CUTLASS_DEVICE
582
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
583
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
584
+ }
585
+
586
+ /// Store a fragment to memory
587
+ CUTLASS_DEVICE
588
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
589
+ iterator_.store_with_byte_offset(frag, byte_offset);
590
+ }
591
+
592
+ /// Store a fragment to memory
593
+ CUTLASS_DEVICE
594
+ void store(Fragment const &frag) {
595
+ store_with_pointer_offset(frag, 0);
596
+ }
597
+ };
598
+
599
+ ////////////////////////////////////////////////////////////////////////////////
600
+
601
+ /// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data.
602
+ ///
603
+ /// Satisfies: ForwardTileIteratorConcept |
604
+ /// ReadableContiguousTileIteratorConcept |
605
+ /// WriteableContiguousTileIteratorConcept |
606
+ /// MaskedTileIteratorConcept
607
+ ///
608
+ template <
609
+ typename Shape_,
610
+ typename Element_,
611
+ int AdvanceRank,
612
+ typename ThreadMap_,
613
+ SideMode kSideMode,
614
+ FillMode kFillMode,
615
+ DiagType kDiagType,
616
+ int AccessSize
617
+ >
618
+ class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_,
619
+ kSideMode, kFillMode, kDiagType,
620
+ AccessSize> {
621
+ public:
622
+
623
+ static_assert(AdvanceRank == 0 || AdvanceRank == 1,
624
+ "Specialization for pitch-linear iterator may along advance along the "
625
+ "contiguous(rank=0) or strided(rank=1) dimension.");
626
+
627
+ using Shape = Shape_;
628
+ using Element = Element_;
629
+ using Layout = layout::RowMajor;
630
+ static int const kAdvanceRank = AdvanceRank;
631
+ using ThreadMap = ThreadMap_;
632
+
633
+ using Index = typename Layout::Index;
634
+ using LongIndex = typename Layout::LongIndex;
635
+
636
+ using TensorRef = TensorRef<Element, Layout>;
637
+ using TensorView = TensorView<Element, Layout>;
638
+ using TensorCoord = typename Layout::TensorCoord;
639
+
640
+ using Pointer = Element *;
641
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
642
+
643
+ using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix<
644
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
645
+ Element,
646
+ layout::PitchLinear,
647
+ (kAdvanceRank == 0 ? 1 : 0),
648
+ ThreadMap,
649
+ kSideMode,
650
+ kFillMode,
651
+ kDiagType,
652
+ AccessSize
653
+ >;
654
+
655
+ using AccessType = typename UnderlyingIterator::AccessType;
656
+
657
+ /// Fragment object to be loaded or stored
658
+ using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
659
+
660
+ /// Predicate vector stores mask to guard accesses
661
+ using Mask = typename UnderlyingIterator::Mask;
662
+
663
+ /// Parameters object is precomputed state and is host-constructible
664
+ class Params {
665
+ private:
666
+
667
+ friend PredicatedTileIteratorTriangularMatrix;
668
+
669
+ /// Parameters object
670
+ typename UnderlyingIterator::Params params_;
671
+
672
+ public:
673
+
674
+ CUTLASS_HOST_DEVICE
675
+ Params() { }
676
+
677
+ /// Construct the Params object given a pitch-linear tensor's layout
678
+ CUTLASS_HOST_DEVICE
679
+ Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
680
+
681
+ };
682
+ };
683
+
684
+
685
+ private:
686
+
687
+ //
688
+ // Data members
689
+ //
690
+
691
+ /// Underlying pitch-linear tile iterator
692
+ UnderlyingIterator iterator_;
693
+
694
+ public:
695
+
696
+ /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
697
+ CUTLASS_HOST_DEVICE
698
+ PredicatedTileIteratorTriangularMatrix(
699
+ Params const &params, ///< Precomputed parameters object
700
+ Pointer pointer, ///< Pointer to start of tensor
701
+ TensorCoord extent, ///< Extent of tensor
702
+ int thread_id, ///< ID of each participating thread
703
+ TensorCoord const &threadblock_offset ///< Initial offset of threadblock
704
+ ):
705
+ iterator_(
706
+ params.params_,
707
+ pointer,
708
+ layout::PitchLinearCoord(extent.column(), extent.row()),
709
+ thread_id,
710
+ layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
711
+ ) { }
712
+
713
+ /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
714
+ CUTLASS_HOST_DEVICE
715
+ PredicatedTileIteratorTriangularMatrix(
716
+ Params const &params, ///< Precomputed parameters object
717
+ Pointer pointer, ///< Pointer to start of tensor
718
+ TensorCoord extent, ///< Extent of tensor
719
+ int thread_id ///< ID of each participating thread
720
+ ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
721
+
722
+ /// Adds a pointer offset in units of Element
723
+ CUTLASS_HOST_DEVICE
724
+ void add_pointer_offset(LongIndex pointer_offset) {
725
+ iterator_.add_pointer_offset(pointer_offset);
726
+ }
727
+
728
+ /// Advances to the next tile in memory.
729
+ ///
730
+ /// The first time this method is called, predicates are updated, and the iterator's
731
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
732
+ /// are lightweight and must only update the internal pointer.
733
+ CUTLASS_HOST_DEVICE
734
+ PredicatedTileIteratorTriangularMatrix &operator++() {
735
+ ++iterator_;
736
+ return *this;
737
+ }
738
+
739
+ /// Advances to the next tile in memory.
740
+ ///
741
+ /// The first time this method is called, predicates are updated, and the iterator's
742
+ /// internal pointer is reverted to the first "steady state" tile. Subsequent calls
743
+ /// are lightweight and must only update the internal pointer.
744
+ CUTLASS_HOST_DEVICE
745
+ PredicatedTileIteratorTriangularMatrix operator++(int) {
746
+ PredicatedTileIteratorTriangularMatrix self(*this);
747
+ operator++();
748
+ return self;
749
+ }
750
+
751
+ /// Clears the predicate set efficiently
752
+ CUTLASS_HOST_DEVICE
753
+ void clear_mask(bool enable = true) {
754
+ iterator_.clear_mask(enable);
755
+ }
756
+
757
+ /// Clears the predicate set efficiently
758
+ CUTLASS_HOST_DEVICE
759
+ void enable_mask() {
760
+ iterator_.enable_mask();
761
+ }
762
+
763
+ /// Sets the predicate mask, overriding value stored in predicate iterator
764
+ CUTLASS_HOST_DEVICE
765
+ void set_mask(Mask const &mask) {
766
+ iterator_.set_mask(mask);
767
+ }
768
+
769
+ /// Gets the mask
770
+ CUTLASS_HOST_DEVICE
771
+ void get_mask(Mask &mask) {
772
+ iterator_.get_mask(mask);
773
+ }
774
+
775
+ /// Loads a fragment from memory
776
+ CUTLASS_DEVICE
777
+ void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
778
+ iterator_.load_with_pointer_offset(frag, pointer_offset);
779
+ }
780
+
781
+ /// Loads a fragment from memory
782
+ CUTLASS_DEVICE
783
+ void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
784
+ iterator_.load_with_byte_offset(frag, byte_offset);
785
+ }
786
+
787
+ /// Loads a fragment from memory
788
+ CUTLASS_DEVICE
789
+ void load(Fragment &frag) {
790
+ load_with_pointer_offset(frag, 0);
791
+ }
792
+
793
+ /// Store a fragment to memory
794
+ CUTLASS_DEVICE
795
+ void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
796
+ iterator_.store_with_pointer_offset(frag, pointer_offset);
797
+ }
798
+
799
+ /// Store a fragment to memory
800
+ CUTLASS_DEVICE
801
+ void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
802
+ iterator_.store_with_byte_offset(frag, byte_offset);
803
+ }
804
+
805
+ /// Store a fragment to memory
806
+ CUTLASS_DEVICE
807
+ void store(Fragment const &frag) {
808
+ store_with_pointer_offset(frag, 0);
809
+ }
810
+ };
811
+
812
+ ////////////////////////////////////////////////////////////////////////////////
813
+
814
+ } // namespace threadblock
815
+ } // namespace transform
816
+ } // namespace cutlass
817
+
818
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Templates implementing computing the addresses of loading small
34
+ vectors from the global memory.
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/coord.h"
42
+ #include "cutlass/layout/pitch_linear.h"
43
+ #include "cutlass/layout/matrix.h"
44
+ #include "cutlass/layout/tensor.h"
45
+ #include "cutlass/matrix_coord.h"
46
+ #include "cutlass/matrix_shape.h"
47
+ #include "cutlass/tensor_ref.h"
48
+
49
+ ////////////////////////////////////////////////////////////////////////////////
50
+
51
+ namespace cutlass {
52
+ namespace transform {
53
+ namespace threadblock {
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ /// PredicatedVectorAccessIterator
58
+ ///
59
+ template <
60
+ /// Shape of the vector accessed by the entire threadblock
61
+ typename Shape,
62
+ /// Shape of the vector accessed by the warp
63
+ typename WarpShape,
64
+ /// Type of Element
65
+ typename Element,
66
+ /// Layout of the vector
67
+ typename Layout,
68
+ /// Number of elements for each access
69
+ int ElementsPerAccess,
70
+ /// Support residual tile
71
+ bool EnableResidualAccess = false
72
+ >
73
+ class PredicatedVectorAccessIterator;
74
+
75
+ ////////////////////////////////////////////////////////////////////////////////
76
+
77
+ /// Vector access iterator specialized for vectors, e.g. scale and bias
78
+ /// Thread arrangements are for TensorOps
79
+ ///
80
+ template <
81
+ typename Shape_,
82
+ typename WarpShape_,
83
+ typename Element_,
84
+ int ElementsPerAccess,
85
+ bool EnableResidualAccess
86
+ >
87
+ class PredicatedVectorAccessIterator <
88
+ Shape_,
89
+ WarpShape_,
90
+ Element_,
91
+ layout::PitchLinear,
92
+ ElementsPerAccess,
93
+ EnableResidualAccess
94
+ > {
95
+ public:
96
+
97
+ using Shape = Shape_;
98
+ using WarpShape = WarpShape_;
99
+ using Element = Element_;
100
+ using Layout = layout::PitchLinear;
101
+
102
+ using Index = typename Layout::Index;
103
+ using LongIndex = typename Layout::LongIndex;
104
+
105
+ using TensorRef = TensorRef<Element, Layout>;
106
+ using TensorView = TensorView<Element, Layout>;
107
+ using TensorCoord = typename Layout::TensorCoord;
108
+
109
+ using ConstPointer = const Element *;
110
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
111
+
112
+ // static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
113
+ static int const kElementsPerAccess = ElementsPerAccess;
114
+ static int const kThreads = 32;
115
+ static int const kRowsPerIteration = 8;
116
+ static int const kThreadsPerRow = kThreads / kRowsPerIteration;
117
+ static int const kThreadsPerRowMask = 0x3;
118
+ static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess);
119
+ static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided;
120
+
121
+ using AccessType = AlignedArray<Element, kElementsPerAccess>;
122
+
123
+ private:
124
+ /// Internal pointer type permits fast address arithmetic
125
+ using BytePointer = char *;
126
+
127
+ private:
128
+ //
129
+ // Data members
130
+ //
131
+
132
+ /// Internal pointer to first access of tile
133
+ BytePointer pointer_;
134
+
135
+ /// Extent of tensor
136
+ TensorCoord extent_;
137
+
138
+ /// pointer offset of each thread
139
+ TensorCoord thread_offset_;
140
+
141
+ /// iteration index
142
+ LongIndex iteration_;
143
+
144
+ /// residual access
145
+ bool is_residual_;
146
+
147
+ /// residual offset of each thread
148
+ TensorCoord residual_offset_;
149
+
150
+ public:
151
+ /// Constructs a vector access iterator
152
+ CUTLASS_HOST_DEVICE
153
+ PredicatedVectorAccessIterator(
154
+ /// Pointer to the start of the vector
155
+ ConstPointer pointer,
156
+ /// Extent of vector
157
+ TensorCoord extent,
158
+ /// ID of each participating thread
159
+ int thread_id,
160
+ /// ID of each participating warp
161
+ int warp_id,
162
+ /// Initial offset of threadblock
163
+ TensorCoord const &threadblock_offset)
164
+ : pointer_(reinterpret_cast<BytePointer>(
165
+ const_cast<NonConstPointer>(pointer))),
166
+ extent_(extent),
167
+ is_residual_(false) {
168
+
169
+
170
+ int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous;
171
+
172
+ // Per-thread offset in logical coordinates of tensor
173
+
174
+ thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) +
175
+ TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0);
176
+
177
+ set_iteration_index(0);
178
+
179
+ if(EnableResidualAccess) {
180
+ // compute residual offset
181
+ typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous;
182
+ if (residual_size) {
183
+ is_residual_ = true;
184
+ residual_offset_ = make_Coord(residual_size, 0);
185
+ }
186
+ }
187
+ }
188
+
189
+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset
190
+ CUTLASS_HOST_DEVICE
191
+ PredicatedVectorAccessIterator(
192
+ /// Pointer to start of vector
193
+ ConstPointer pointer,
194
+ /// Extent of vector
195
+ TensorCoord extent,
196
+ ///< ID of each participating thread
197
+ int thread_id,
198
+ /// ID of each participating warp
199
+ int warp_id)
200
+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id,
201
+ make_Coord(0, 0)) {}
202
+
203
+
204
+ /// Overrides the internal iteration index
205
+ CUTLASS_HOST_DEVICE
206
+ void set_iteration_index(int index) {
207
+ iteration_ = index;
208
+ }
209
+
210
+ /// Advances an iterator along logical dimensions of matrix in units of whole tiles
211
+ CUTLASS_DEVICE
212
+ void add_tile_offset(
213
+ TensorCoord const &tile_offset) {
214
+
215
+ thread_offset_ =
216
+ thread_offset_ +
217
+ TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0);
218
+ }
219
+
220
+ /// Returns a pointer
221
+ CUTLASS_HOST_DEVICE
222
+ AccessType *get() const {
223
+
224
+ return reinterpret_cast<AccessType *>(
225
+ pointer_ +
226
+ ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess)
227
+ * sizeof_bits<Element>::value / 8));
228
+ }
229
+
230
+ /// Increment and return an instance to self.
231
+ CUTLASS_HOST_DEVICE
232
+ PredicatedVectorAccessIterator &operator++() {
233
+ ++iteration_;
234
+ if(iteration_ >= kIterations)
235
+ iteration_ = 0;
236
+
237
+ return *this;
238
+ }
239
+
240
+ /// Increment and return an instance to self.
241
+ CUTLASS_HOST_DEVICE
242
+ void advance() {
243
+ if(EnableResidualAccess && is_residual_) {
244
+ is_residual_ = false;
245
+ thread_offset_ += residual_offset_;
246
+ }
247
+ else
248
+ add_tile_offset(TensorCoord(1, 0));
249
+ }
250
+
251
+ /// Increment and return an instance to self.
252
+ CUTLASS_HOST_DEVICE
253
+ PredicatedVectorAccessIterator operator++(int) {
254
+ PredicatedVectorAccessIterator self(*this);
255
+ operator++();
256
+ return self;
257
+ }
258
+
259
+ /// Returns whether access is valid or not
260
+ CUTLASS_HOST_DEVICE
261
+ bool valid() {
262
+ return ((thread_offset_.contiguous() +
263
+ iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous());
264
+ }
265
+ };
266
+
267
+ ////////////////////////////////////////////////////////////////////////////////
268
+
269
+ /// Specialization of PredicatedVectorAccessIterator for row-major data.
270
+ ///
271
+ template <
272
+ typename Shape_,
273
+ typename WarpShape_,
274
+ typename Element_,
275
+ int ElementsPerAccess,
276
+ bool EnableResidualAccess
277
+ >
278
+ class PredicatedVectorAccessIterator<
279
+ Shape_,
280
+ WarpShape_,
281
+ Element_,
282
+ layout::RowMajor,
283
+ ElementsPerAccess,
284
+ EnableResidualAccess
285
+ > {
286
+ public:
287
+
288
+ using Shape = Shape_;
289
+ using WarpShape = WarpShape_;
290
+ using Element = Element_;
291
+ using Layout = layout::RowMajor;
292
+
293
+ using Index = typename Layout::Index;
294
+ using LongIndex = typename Layout::LongIndex;
295
+
296
+ using TensorRef = TensorRef<Element, Layout>;
297
+ using TensorView = TensorView<Element, Layout>;
298
+ using TensorCoord = typename Layout::TensorCoord;
299
+
300
+ using ConstPointer = const Element *;
301
+ using NonConstPointer = typename platform::remove_const<Element>::type *;
302
+
303
+ using UnderlyingIterator = PredicatedVectorAccessIterator<
304
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
305
+ layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
306
+ Element,
307
+ layout::PitchLinear,
308
+ ElementsPerAccess,
309
+ EnableResidualAccess>;
310
+
311
+ using AccessType = typename UnderlyingIterator::AccessType;
312
+ static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
313
+ static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration;
314
+ static int const kThreads = UnderlyingIterator::kThreads;
315
+ static int const kIterations = UnderlyingIterator::kIterations;
316
+
317
+ private:
318
+ //
319
+ // Data members
320
+ //
321
+
322
+ /// Underlying pitch-linear tile iterator
323
+ UnderlyingIterator iterator_;
324
+
325
+ public:
326
+ /// Constructs a TileIterator from its precomputed state, threadblock offset,
327
+ /// and thread ID
328
+ CUTLASS_HOST_DEVICE
329
+ PredicatedVectorAccessIterator(
330
+ ///< Pointer to the start of the vector
331
+ ConstPointer pointer,
332
+ ///< Extent of tensor
333
+ TensorCoord extent,
334
+ ///< ID of each participating thread
335
+ int thread_id,
336
+ ///< ID of each participating warp
337
+ int warp_id,
338
+ ///< Initial offset of threadblock
339
+ TensorCoord const &threadblock_offset)
340
+ : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()),
341
+ thread_id, warp_id,
342
+ layout::PitchLinearCoord(threadblock_offset.column(),
343
+ threadblock_offset.row())) {}
344
+
345
+ /// Construct a PredicatedVectorAccessIterator with zero threadblock offset
346
+ CUTLASS_HOST_DEVICE
347
+ PredicatedVectorAccessIterator(
348
+ ConstPointer pointer, ///< Pointer to the start of the vector
349
+ TensorCoord extent, ///< Extent of tensor
350
+ int thread_id, ///< ID of each participating thread
351
+ int warp_id ///< ID of each participating warp
352
+ )
353
+ : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id,
354
+ make_Coord(0, 0)) {}
355
+
356
+ /// Overrides the internal iteration index
357
+ CUTLASS_HOST_DEVICE
358
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
359
+
360
+ /// Advances an iterator along logical dimensions of matrix in units of whole
361
+ /// tiles
362
+ CUTLASS_HOST_DEVICE
363
+ void add_tile_offset(TensorCoord const &tile_offset) {
364
+ iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
365
+ }
366
+
367
+ /// Returns a pointer
368
+ CUTLASS_HOST_DEVICE
369
+ AccessType *get() const {
370
+ return reinterpret_cast<AccessType *>(iterator_.get());
371
+ }
372
+
373
+ /// Advances to the next tile in memory.
374
+ ///
375
+ /// The first time this method is called, predicates are updated, and the
376
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
377
+ /// Subsequent calls are lightweight and must only update the internal
378
+ /// pointer.
379
+ CUTLASS_HOST_DEVICE
380
+ PredicatedVectorAccessIterator &operator++() {
381
+ ++iterator_;
382
+ return *this;
383
+ }
384
+
385
+ /// Advances to the next tile in memory.
386
+ ///
387
+ /// The first time this method is called, predicates are updated, and the
388
+ /// iterator's internal pointer is reverted to the first "steady state" tile.
389
+ /// Subsequent calls are lightweight and must only update the internal
390
+ /// pointer.
391
+ CUTLASS_HOST_DEVICE
392
+ PredicatedVectorAccessIterator operator++(int) {
393
+ PredicatedVectorAccessIterator self(*this);
394
+ operator++();
395
+ return self;
396
+ }
397
+
398
+ /// Increment and return an instance to self.
399
+ CUTLASS_HOST_DEVICE
400
+ void advance() {
401
+ iterator_.advance();
402
+ }
403
+
404
+ /// Returns whether access is valid or not
405
+ CUTLASS_HOST_DEVICE
406
+ bool valid() {
407
+ return iterator_.valid();
408
+ }
409
+ };
410
+
411
+
412
+ ////////////////////////////////////////////////////////////////////////////////
413
+
414
+ } // namespace threadblock
415
+ } // namespace transform
416
+ } // namespace cutlass
417
+
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Templates implementing computing the addresses of storing of small
34
+ scale and bias vectors in the shared memory.
35
+ */
36
+
37
+ #pragma once
38
+
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/array.h"
41
+ #include "cutlass/layout/pitch_linear.h"
42
+ #include "cutlass/layout/matrix.h"
43
+ #include "cutlass/matrix_coord.h"
44
+ #include "cutlass/matrix_shape.h"
45
+ #include "cutlass/tensor_ref.h"
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass {
50
+ namespace transform {
51
+ namespace threadblock {
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////
54
+
55
+ /// RegularScaleBiasVectorAccessIterator
56
+ ///
57
+ template <typename Shape, typename Element, typename Layout>
58
+ class RegularScaleBiasVectorAccessIterator;
59
+
60
+ ////////////////////////////////////////////////////////////////////////////////
61
+
62
+ /// Tile iterator specialized for congruous arrangements for TensorOps
63
+ ///
64
+ ///
65
+ /// Satisfies: ForwardTileIteratorConcept |
66
+ /// ReadableContiguousTileIteratorConcept |
67
+ /// WriteableContiguousTileIteratorConcept
68
+ ///
69
+ template <typename Shape_, typename Element_>
70
+ class RegularScaleBiasVectorAccessIterator<Shape_, Element_, layout::PitchLinear> {
71
+ public:
72
+
73
+ using Shape = Shape_;
74
+ using Element = Element_;
75
+ using Layout = layout::PitchLinear;
76
+
77
+ using Index = typename Layout::Index;
78
+ using LongIndex = typename Layout::LongIndex;
79
+
80
+ using TensorRef = TensorRef<Element, Layout>;
81
+ using TensorCoord = typename Layout::TensorCoord;
82
+
83
+ /// Element type per access
84
+ static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
85
+ static int const kThreads = Shape::kContiguous / kElementsPerAccess;
86
+ using AccessType = Array<Element, kElementsPerAccess>;
87
+
88
+ private:
89
+ //
90
+ // Data members
91
+ //
92
+
93
+ /// Internal pointer
94
+ AccessType *pointer_;
95
+
96
+ /// Internal byte offset
97
+ Index byte_offset_;
98
+
99
+ public:
100
+ /// Construct a TileIterator with zero threadblock offset
101
+ CUTLASS_HOST_DEVICE
102
+ RegularScaleBiasVectorAccessIterator(
103
+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
104
+ ///< vector
105
+ int thread_id ///< ID of each participating thread
106
+ )
107
+ : byte_offset_(0) {
108
+ // Per-thread offset in logical coordinates of tensor
109
+ int thread_offset = thread_id * kElementsPerAccess;
110
+
111
+ // initialize pointer
112
+ pointer_ =
113
+ reinterpret_cast<AccessType *>(scale_bias_ref.data() + thread_offset);
114
+
115
+ set_iteration_index(0);
116
+ }
117
+
118
+ /// Overrides the internal iteration index
119
+ CUTLASS_HOST_DEVICE
120
+ void set_iteration_index(int index) {}
121
+
122
+ /// Adds a pointer offset in units of Element
123
+ CUTLASS_HOST_DEVICE
124
+ void add_pointer_offset(LongIndex pointer_offset) {
125
+ byte_offset_ += pointer_offset * sizeof(Element);
126
+ }
127
+
128
+ /// Returns a pointer
129
+ CUTLASS_DEVICE
130
+ AccessType *get() const {
131
+
132
+ char *access_byte_ptr =
133
+ reinterpret_cast<char *>(pointer_);
134
+
135
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
136
+ }
137
+
138
+ /// Advances to the next tile in memory.
139
+ CUTLASS_HOST_DEVICE
140
+ RegularScaleBiasVectorAccessIterator &operator++() { return *this; }
141
+
142
+ /// Advances to the next tile in memory.
143
+ CUTLASS_HOST_DEVICE
144
+ RegularScaleBiasVectorAccessIterator operator++(int) {
145
+ RegularScaleBiasVectorAccessIterator prev(*this);
146
+ this->operator++();
147
+
148
+ return prev;
149
+ }
150
+
151
+ /// Adds a tile offset in the unit of tile.
152
+ CUTLASS_DEVICE
153
+ void add_tile_offset(TensorCoord const &coord) {
154
+ // Multiply by 2 because we store scale and bias belong to the same stage
155
+ // next to each other.
156
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2);
157
+ }
158
+ };
159
+
160
+ ////////////////////////////////////////////////////////////////////////////////
161
+
162
+ /// Tile iterator specialized for row major layouts
163
+ ///
164
+ ///
165
+ /// Satisfies: ForwardTileIteratorConcept |
166
+ /// ReadableContiguousTileIteratorConcept |
167
+ /// WriteableContiguousTileIteratorConcept
168
+ ///
169
+ template <typename Shape_, typename Element_>
170
+ class RegularScaleBiasVectorAccessIterator<
171
+ Shape_, Element_,
172
+ layout::RowMajor> {
173
+ public:
174
+
175
+ using Shape = Shape_;
176
+ using Element = Element_;
177
+ using Layout = layout::RowMajor;
178
+
179
+ using Index = typename Layout::Index;
180
+ using LongIndex = typename Layout::LongIndex;
181
+
182
+ using TensorRef = TensorRef<Element, Layout>;
183
+ using TensorCoord = typename Layout::TensorCoord;
184
+
185
+ /// Underlying iterator type
186
+ using UnderlyingIterator = RegularScaleBiasVectorAccessIterator<
187
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
188
+ layout::PitchLinear>;
189
+
190
+ using AccessType = typename UnderlyingIterator::AccessType;
191
+
192
+ private:
193
+
194
+ /// Underlying iterator
195
+ UnderlyingIterator iterator_;
196
+
197
+ public:
198
+ /// Construct a TileIterator with zero threadblock offset
199
+ CUTLASS_HOST_DEVICE
200
+ RegularScaleBiasVectorAccessIterator(
201
+ TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
202
+ ///< vector
203
+ int thread_id ///< ID of each participating thread
204
+ )
205
+ : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) {
206
+ }
207
+
208
+ /// Overrides the internal iteration index
209
+ CUTLASS_HOST_DEVICE
210
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
211
+
212
+ /// Adds a pointer offset in units of Element
213
+ CUTLASS_HOST_DEVICE
214
+ void add_pointer_offset(LongIndex pointer_offset) {
215
+ iterator_.add_pointer_offset(pointer_offset);
216
+ }
217
+
218
+ /// Returns a pointer
219
+ CUTLASS_HOST_DEVICE
220
+ AccessType *get() const {
221
+ return reinterpret_cast<AccessType *>(iterator_.get());
222
+ }
223
+
224
+ /// Adds a tile offset
225
+ CUTLASS_DEVICE
226
+ void add_tile_offset(TensorCoord const &coord) {
227
+ iterator_.add_tile_offset({coord.column(), coord.row()});
228
+ }
229
+
230
+ /// Advances to the next tile in memory.
231
+ CUTLASS_HOST_DEVICE
232
+ RegularScaleBiasVectorAccessIterator &operator++() {
233
+ ++iterator_;
234
+ return *this;
235
+ }
236
+
237
+ /// Advances to the next tile in memory.
238
+ CUTLASS_HOST_DEVICE
239
+ RegularScaleBiasVectorAccessIterator operator++(int) {
240
+ RegularScaleBiasVectorAccessIterator prev(*this);
241
+ ++iterator_;
242
+
243
+ return prev;
244
+ }
245
+ };
246
+
247
+ ////////////////////////////////////////////////////////////////////////////////
248
+
249
+ } // namespace threadblock
250
+ } // namespace transform
251
+ } // namespace cutlass
252
+
253
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing the address computation of storing of tiles
33
+ from pitch-linear rank=2 tensors.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+
40
+ ////////////////////////////////////////////////////////////////////////////////
41
+
42
+ namespace cutlass {
43
+ namespace transform {
44
+ namespace threadblock {
45
+
46
+ ////////////////////////////////////////////////////////////////////////////////
47
+
48
+ template <typename Shape, typename Element, typename Layout, int AdvanceRank,
49
+ typename ThreadMap,
50
+ int Alignment =
51
+ sizeof_bits<Element>::value* ThreadMap::kElementsPerAccess / 8>
52
+ class RegularTileAccessIterator;
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////
55
+
56
+ } // namespace threadblock
57
+ } // namespace transform
58
+ } // namespace cutlass
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing computing the addresses of storing of tiles
33
+ from pitch-linear rank=2 tensors.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+ #include "cutlass/layout/matrix.h"
42
+ #include "cutlass/matrix_coord.h"
43
+ #include "cutlass/matrix_shape.h"
44
+ #include "cutlass/tensor_ref.h"
45
+
46
+ #include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace transform {
52
+ namespace threadblock {
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// Tile iterator specialized for congruous arrangements for TensorOps
57
+ ///
58
+ ///
59
+ /// Satisfies: ForwardTileIteratorConcept |
60
+ /// ReadableContiguousTileIteratorConcept |
61
+ /// WriteableContiguousTileIteratorConcept
62
+ ///
63
+ template <typename Shape_, typename Element_, int AdvanceRank,
64
+ typename ThreadMap_, int Alignment>
65
+ class RegularTileAccessIterator<
66
+ Shape_, Element_,
67
+ layout::PitchLinear,
68
+ AdvanceRank, ThreadMap_, Alignment> {
69
+ public:
70
+ static_assert(
71
+ AdvanceRank == 0 || AdvanceRank == 1,
72
+ "Specialization for pitch-linear iterator may along advance along the "
73
+ "contiguous(rank=0) or strided(rank=1) dimension.");
74
+
75
+ using Shape = Shape_;
76
+ using Element = Element_;
77
+ using Layout = layout::PitchLinear;
78
+ static int const kAdvanceRank = AdvanceRank;
79
+ static int const kAlignment = Alignment;
80
+
81
+ using Index = typename Layout::Index;
82
+ using LongIndex = typename Layout::LongIndex;
83
+ using StrideIndex = typename Layout::Stride::Index;
84
+
85
+ using TensorRef = TensorRef<Element, Layout>;
86
+ using TensorCoord = typename Layout::TensorCoord;
87
+
88
+ using ThreadMap = ThreadMap_;
89
+
90
+ /// Element type per access
91
+ using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
92
+
93
+ private:
94
+ //
95
+ // Data members
96
+ //
97
+
98
+ /// Stride value
99
+ StrideIndex stride_;
100
+
101
+ /// Internal pointer to first access of tile
102
+ AccessType *pointer_;
103
+
104
+ /// Internal byte offset
105
+ Index byte_offset_;
106
+
107
+ /// Iteration in the contiguous dimension
108
+ int iteration_contiguous_;
109
+
110
+ /// Iteration in the strided dimension
111
+ int iteration_strided_;
112
+
113
+ public:
114
+ /// Construct a TileIterator with zero threadblock offset
115
+ CUTLASS_HOST_DEVICE
116
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
117
+ int thread_id ///< ID of each participating thread
118
+ )
119
+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
120
+ byte_offset_(0) {
121
+
122
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
123
+
124
+ // initialize pointer
125
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
126
+
127
+ set_iteration_index(0);
128
+ }
129
+
130
+ /// Overrides the internal iteration index
131
+ CUTLASS_HOST_DEVICE
132
+ void set_iteration_index(int index) {
133
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
134
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
135
+ }
136
+
137
+ /// Adds a pointer offset in units of Element
138
+ CUTLASS_HOST_DEVICE
139
+ void add_pointer_offset(LongIndex pointer_offset) {
140
+ byte_offset_ += pointer_offset * sizeof(Element);
141
+ }
142
+
143
+ /// Returns a pointer
144
+ CUTLASS_DEVICE
145
+ AccessType *get() const {
146
+
147
+ AccessType *access_ptr = pointer_;
148
+
149
+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
150
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
151
+ ThreadMap::kElementsPerAccess;
152
+
153
+ char *access_byte_ptr =
154
+ reinterpret_cast<char *>(access_ptr + access_offset);
155
+
156
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
157
+ }
158
+
159
+ /// Advances to the next tile in memory.
160
+ CUTLASS_HOST_DEVICE
161
+ RegularTileAccessIterator &operator++() {
162
+ ++iteration_contiguous_;
163
+
164
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
165
+ return *this;
166
+
167
+ // Enter here only if (iteration_contiguous_ ==
168
+ // ThreadMap::Iteration::kContiguous)
169
+ iteration_contiguous_ = 0;
170
+ ++iteration_strided_;
171
+
172
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
173
+ return *this;
174
+ }
175
+
176
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
177
+ // which means we enter the next tile.
178
+ iteration_strided_ = 0;
179
+
180
+ return *this;
181
+ }
182
+
183
+ /// Advances to the next tile in memory.
184
+ CUTLASS_HOST_DEVICE
185
+ RegularTileAccessIterator operator++(int) {
186
+ RegularTileAccessIterator prev(*this);
187
+ this->operator++();
188
+
189
+ return prev;
190
+ }
191
+
192
+ /// Adds a tile offset in the unit of tile.
193
+ /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory.
194
+ /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B.
195
+ /// For row major A operand, k dimension is contiguous dimension;
196
+ /// For col major A operand, k dimension is strided dimension;
197
+ /// For row major B operand, k dimension is strided dimension;
198
+ /// For col major B operand, k dimension is contiguous dimension.
199
+ /// Below two classes map col/row major to the pitch linear coordinates used
200
+ /// in this base class.
201
+ CUTLASS_DEVICE
202
+ void add_tile_offset(TensorCoord const &coord) {
203
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous +
204
+ coord.strided() * Shape::kStrided * stride_ *
205
+ ThreadMap::kElementsPerAccess);
206
+ }
207
+ };
208
+
209
+ ////////////////////////////////////////////////////////////////////////////////
210
+
211
+ /// Tile iterator specialized for column major layouts
212
+ ///
213
+ ///
214
+ /// Satisfies: ForwardTileIteratorConcept |
215
+ /// ReadableContiguousTileIteratorConcept |
216
+ /// WriteableContiguousTileIteratorConcept
217
+ ///
218
+ template <typename Shape_, typename Element_, int AdvanceRank,
219
+ typename ThreadMap_, int Alignment>
220
+ class RegularTileAccessIterator<
221
+ Shape_, Element_,
222
+ layout::ColumnMajor,
223
+ AdvanceRank, ThreadMap_, Alignment> {
224
+ public:
225
+ static_assert(
226
+ AdvanceRank == 0 || AdvanceRank == 1,
227
+ "Specialization for pitch-linear iterator may along advance along the "
228
+ "contiguous(rank=0) or strided(rank=1) dimension.");
229
+
230
+ using Shape = Shape_;
231
+ using Element = Element_;
232
+ using Layout = layout::ColumnMajor;
233
+ static int const kAdvanceRank = AdvanceRank;
234
+ static int const kAlignment = Alignment;
235
+
236
+ using Index = typename Layout::Index;
237
+ using LongIndex = typename Layout::LongIndex;
238
+
239
+ using TensorRef = TensorRef<Element, Layout>;
240
+ using TensorCoord = typename Layout::TensorCoord;
241
+
242
+ using ThreadMap = ThreadMap_;
243
+
244
+ /// Underlying iterator type
245
+ using UnderlyingIterator = RegularTileAccessIterator<
246
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
247
+ layout::PitchLinear,
248
+ (kAdvanceRank == 0 ? 0 : 1),
249
+ ThreadMap_>;
250
+
251
+ using AccessType = typename UnderlyingIterator::AccessType;
252
+
253
+ private:
254
+
255
+ /// Underlying iterator
256
+ UnderlyingIterator iterator_;
257
+
258
+ public:
259
+ /// Construct a TileIterator with zero threadblock offset
260
+ CUTLASS_HOST_DEVICE
261
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
262
+ int thread_id ///< ID of each participating thread
263
+ )
264
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
265
+
266
+ /// Overrides the internal iteration index
267
+ CUTLASS_HOST_DEVICE
268
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
269
+
270
+ /// Adds a pointer offset in units of Element
271
+ CUTLASS_HOST_DEVICE
272
+ void add_pointer_offset(LongIndex pointer_offset) {
273
+ iterator_.add_pointer_offset(pointer_offset);
274
+ }
275
+
276
+ /// Returns a pointer
277
+ CUTLASS_HOST_DEVICE
278
+ AccessType *get() const {
279
+ return reinterpret_cast<AccessType *>(iterator_.get());
280
+ }
281
+
282
+ /// Adds a tile offset
283
+ CUTLASS_DEVICE
284
+ void add_tile_offset(TensorCoord const &coord) {
285
+ iterator_.add_tile_offset({coord.row(), coord.column()});
286
+ }
287
+
288
+ /// Advances to the next tile in memory.
289
+ CUTLASS_HOST_DEVICE
290
+ RegularTileAccessIterator &operator++() {
291
+ ++iterator_;
292
+ return *this;
293
+ }
294
+
295
+ /// Advances to the next tile in memory.
296
+ CUTLASS_HOST_DEVICE
297
+ RegularTileAccessIterator operator++(int) {
298
+ RegularTileAccessIterator prev(*this);
299
+ ++iterator_;
300
+
301
+ return prev;
302
+ }
303
+ };
304
+
305
+
306
+ ////////////////////////////////////////////////////////////////////////////////
307
+
308
+ /// Tile iterator specialized for row major layouts
309
+ ///
310
+ ///
311
+ /// Satisfies: ForwardTileIteratorConcept |
312
+ /// ReadableContiguousTileIteratorConcept |
313
+ /// WriteableContiguousTileIteratorConcept
314
+ ///
315
+ template <typename Shape_, typename Element_, int AdvanceRank,
316
+ typename ThreadMap_, int Alignment>
317
+ class RegularTileAccessIterator<
318
+ Shape_, Element_,
319
+ layout::RowMajor,
320
+ AdvanceRank, ThreadMap_, Alignment> {
321
+ public:
322
+ static_assert(
323
+ AdvanceRank == 0 || AdvanceRank == 1,
324
+ "Specialization for pitch-linear iterator may along advance along the "
325
+ "contiguous(rank=0) or strided(rank=1) dimension.");
326
+
327
+ using Shape = Shape_;
328
+ using Element = Element_;
329
+ using Layout = layout::RowMajor;
330
+ static int const kAdvanceRank = AdvanceRank;
331
+ static int const kAlignment = Alignment;
332
+
333
+ using Index = typename Layout::Index;
334
+ using LongIndex = typename Layout::LongIndex;
335
+
336
+ using TensorRef = TensorRef<Element, Layout>;
337
+ using TensorCoord = typename Layout::TensorCoord;
338
+
339
+ using ThreadMap = ThreadMap_;
340
+
341
+ /// Underlying iterator type
342
+ using UnderlyingIterator = RegularTileAccessIterator<
343
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
344
+ layout::PitchLinear,
345
+ (kAdvanceRank == 0 ? 1 : 0),
346
+ ThreadMap_>;
347
+
348
+ using AccessType = typename UnderlyingIterator::AccessType;
349
+
350
+ private:
351
+
352
+ /// Underlying iterator
353
+ UnderlyingIterator iterator_;
354
+
355
+ public:
356
+ /// Construct a TileIterator with zero threadblock offset
357
+ CUTLASS_HOST_DEVICE
358
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
359
+ int thread_id ///< ID of each participating thread
360
+ )
361
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
362
+
363
+ /// Overrides the internal iteration index
364
+ CUTLASS_HOST_DEVICE
365
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
366
+
367
+ /// Adds a pointer offset in units of Element
368
+ CUTLASS_HOST_DEVICE
369
+ void add_pointer_offset(LongIndex pointer_offset) {
370
+ iterator_.add_pointer_offset(pointer_offset);
371
+ }
372
+
373
+ /// Returns a pointer
374
+ CUTLASS_HOST_DEVICE
375
+ AccessType *get() const {
376
+ return reinterpret_cast<AccessType *>(iterator_.get());
377
+ }
378
+
379
+ /// Adds a tile offset
380
+ CUTLASS_DEVICE
381
+ void add_tile_offset(TensorCoord const &coord) {
382
+ iterator_.add_tile_offset({coord.column(), coord.row()});
383
+ }
384
+
385
+ /// Advances to the next tile in memory.
386
+ CUTLASS_HOST_DEVICE
387
+ RegularTileAccessIterator &operator++() {
388
+ ++iterator_;
389
+ return *this;
390
+ }
391
+
392
+ /// Advances to the next tile in memory.
393
+ CUTLASS_HOST_DEVICE
394
+ RegularTileAccessIterator operator++(int) {
395
+ RegularTileAccessIterator prev(*this);
396
+ ++iterator_;
397
+
398
+ return prev;
399
+ }
400
+ };
401
+
402
+ ////////////////////////////////////////////////////////////////////////////////
403
+
404
+ } // namespace threadblock
405
+ } // namespace transform
406
+ } // namespace cutlass
407
+
408
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h ADDED
@@ -0,0 +1,587 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing computing the addresses of storing of tiles
33
+ from pitch-linear rank=2 tensors.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/array.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+ #include "cutlass/layout/matrix.h"
42
+ #include "cutlass/matrix_coord.h"
43
+ #include "cutlass/matrix_shape.h"
44
+ #include "cutlass/tensor_ref.h"
45
+
46
+ #include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace transform {
52
+ namespace threadblock {
53
+
54
+
55
+ ////////////////////////////////////////////////////////////////////////////////
56
+
57
+ template <typename Shape, typename Element, typename Layout, int AdvanceRank,
58
+ typename ThreadMap,
59
+ bool Dynamic_iterations = false,
60
+ int Alignment =
61
+ sizeof_bits<Element>::value* ThreadMap::kElementsPerAccess / 8
62
+ >
63
+ class RegularTileAccessIteratorDirectConv;
64
+
65
+ ////////////////////////////////////////////////////////////////////////////////
66
+
67
+ /// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF
68
+ ///
69
+ ///
70
+ /// Satisfies: ForwardTileIteratorConcept |
71
+ /// ReadableContiguousTileIteratorConcept |
72
+ /// WriteableContiguousTileIteratorConcept
73
+ ///
74
+ template <typename Shape_, typename Element_, int AdvanceRank,
75
+ typename ThreadMap_, int Alignment>
76
+ class RegularTileAccessIteratorDirectConv<
77
+ Shape_, Element_,
78
+ layout::PitchLinear,
79
+ AdvanceRank, ThreadMap_, false, Alignment> {
80
+ public:
81
+ static_assert(
82
+ AdvanceRank == 0 || AdvanceRank == 1,
83
+ "Specialization for pitch-linear iterator may along advance along the "
84
+ "contiguous(rank=0) or strided(rank=1) dimension.");
85
+
86
+ using Shape = Shape_;
87
+ using Element = Element_;
88
+ using Layout = layout::PitchLinear;
89
+ static int const kAdvanceRank = AdvanceRank;
90
+ static int const kAlignment = Alignment;
91
+
92
+ using Index = typename Layout::Index;
93
+ using LongIndex = typename Layout::LongIndex;
94
+ using StrideIndex = typename Layout::Stride::Index;
95
+
96
+ using TensorRef = TensorRef<Element, Layout>;
97
+ using TensorCoord = typename Layout::TensorCoord;
98
+
99
+ using ThreadMap = ThreadMap_;
100
+
101
+ /// Element type per access
102
+ using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
103
+
104
+ private:
105
+ //
106
+ // Data members
107
+ //
108
+
109
+ /// Stride value
110
+ StrideIndex stride_;
111
+
112
+ /// Internal pointer to first access of tile
113
+ AccessType *pointer_;
114
+
115
+ /// Internal byte offset
116
+ Index byte_offset_;
117
+
118
+ /// Iteration in the contiguous dimension
119
+ int iteration_contiguous_;
120
+
121
+ /// Iteration in the strided dimension
122
+ int iteration_strided_;
123
+
124
+ public:
125
+ /// Construct a TileIterator with zero threadblock offset
126
+ CUTLASS_HOST_DEVICE
127
+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
128
+ int thread_id ///< ID of each participating thread
129
+ )
130
+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
131
+ byte_offset_(0) {
132
+
133
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
134
+
135
+ // initialize pointer
136
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
137
+
138
+ set_iteration_index(0);
139
+ }
140
+
141
+ /// Overrides the internal iteration index
142
+ CUTLASS_HOST_DEVICE
143
+ void set_iteration_index(int index) {
144
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
145
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
146
+ }
147
+
148
+ /// Overrides the internal iteration index
149
+ CUTLASS_HOST_DEVICE
150
+ void set_iteration_num(int num) {
151
+ //Do nothing
152
+ }
153
+
154
+ /// Adds a pointer offset in units of Element
155
+ CUTLASS_HOST_DEVICE
156
+ void add_pointer_offset(LongIndex pointer_offset) {
157
+ byte_offset_ += pointer_offset * sizeof(Element);
158
+ }
159
+
160
+ /// Returns a pointer
161
+ CUTLASS_DEVICE
162
+ AccessType *get() const {
163
+
164
+ AccessType *access_ptr = pointer_;
165
+
166
+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
167
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
168
+ ThreadMap::kElementsPerAccess;
169
+
170
+ char *access_byte_ptr =
171
+ reinterpret_cast<char *>(access_ptr + access_offset);
172
+
173
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
174
+ }
175
+
176
+ /// Advances to the next tile in memory.
177
+ CUTLASS_HOST_DEVICE
178
+ RegularTileAccessIteratorDirectConv &operator++() {
179
+ ++iteration_contiguous_;
180
+
181
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
182
+ return *this;
183
+
184
+ // Enter here only if (iteration_contiguous_ ==
185
+ // ThreadMap::Iteration::kContiguous)
186
+ iteration_contiguous_ = 0;
187
+ ++iteration_strided_;
188
+
189
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
190
+ return *this;
191
+ }
192
+
193
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
194
+ // which means we enter the next tile.
195
+ iteration_strided_ = 0;
196
+
197
+ return *this;
198
+ }
199
+
200
+ /// Advances to the next tile in memory.
201
+ CUTLASS_HOST_DEVICE
202
+ RegularTileAccessIteratorDirectConv operator++(int) {
203
+ RegularTileAccessIteratorDirectConv prev(*this);
204
+ this->operator++();
205
+
206
+ return prev;
207
+ }
208
+
209
+ /// Adds a tile offset in the unit of tile.
210
+ CUTLASS_DEVICE
211
+ void add_tile_offset(TensorCoord const &coord) {
212
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous +
213
+ coord.strided() * ThreadMap::Iterations::kStrided *
214
+ ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess);
215
+ }
216
+ };
217
+
218
+ ////////////////////////////////////////////////////////////////////////////////
219
+
220
+ /// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON
221
+ ///
222
+ ///
223
+ /// Satisfies: ForwardTileIteratorConcept |
224
+ /// ReadableContiguousTileIteratorConcept |
225
+ /// WriteableContiguousTileIteratorConcept
226
+ ///
227
+ template <typename Shape_, typename Element_, int AdvanceRank,
228
+ typename ThreadMap_, int Alignment>
229
+ class RegularTileAccessIteratorDirectConv<
230
+ Shape_, Element_,
231
+ layout::PitchLinear,
232
+ AdvanceRank, ThreadMap_,true, Alignment> {
233
+ public:
234
+ static_assert(
235
+ AdvanceRank == 0 || AdvanceRank == 1,
236
+ "Specialization for pitch-linear iterator may along advance along the "
237
+ "contiguous(rank=0) or strided(rank=1) dimension.");
238
+
239
+ using Shape = Shape_;
240
+ using Element = Element_;
241
+ using Layout = layout::PitchLinear;
242
+ static int const kAdvanceRank = AdvanceRank;
243
+ static int const kAlignment = Alignment;
244
+
245
+ using Index = typename Layout::Index;
246
+ using LongIndex = typename Layout::LongIndex;
247
+ using StrideIndex = typename Layout::Stride::Index;
248
+
249
+ using TensorRef = TensorRef<Element, Layout>;
250
+ using TensorCoord = typename Layout::TensorCoord;
251
+
252
+ using ThreadMap = ThreadMap_;
253
+
254
+ /// Element type per access
255
+ using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
256
+
257
+ private:
258
+ //
259
+ // Data members
260
+ //
261
+
262
+ /// Stride value
263
+ StrideIndex stride_;
264
+
265
+ /// Internal pointer to first access of tile
266
+ AccessType *pointer_;
267
+
268
+ /// Internal byte offset
269
+ Index byte_offset_;
270
+
271
+ /// Iteration in the contiguous dimension
272
+ int iteration_contiguous_;
273
+
274
+ /// Iteration in the strided dimension
275
+ int iteration_strided_;
276
+
277
+ /// Total iterattions in the strided dimension: Dynamic value
278
+ int total_iteration_strided_;
279
+
280
+ public:
281
+ /// Construct a TileIterator with zero threadblock offset
282
+ CUTLASS_HOST_DEVICE
283
+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
284
+ int thread_id ///< ID of each participating thread
285
+ )
286
+ : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
287
+ byte_offset_(0) {
288
+
289
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
290
+
291
+ // initialize pointer
292
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
293
+
294
+ set_iteration_index(0);
295
+ }
296
+
297
+ /// Overrides the internal iteration index
298
+ CUTLASS_HOST_DEVICE
299
+ void set_iteration_index(int index) {
300
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
301
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
302
+ }
303
+
304
+ /// Overrides the internal iteration index
305
+ CUTLASS_HOST_DEVICE
306
+ void set_iteration_num(int num) {
307
+ total_iteration_strided_ = num;
308
+ }
309
+
310
+ /// Adds a pointer offset in units of Element
311
+ CUTLASS_HOST_DEVICE
312
+ void add_pointer_offset(LongIndex pointer_offset) {
313
+ byte_offset_ += pointer_offset * sizeof(Element);
314
+ }
315
+
316
+ /// Returns a pointer
317
+ CUTLASS_DEVICE
318
+ AccessType *get() const {
319
+
320
+ AccessType *access_ptr = pointer_;
321
+
322
+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
323
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
324
+ ThreadMap::kElementsPerAccess;
325
+
326
+ char *access_byte_ptr =
327
+ reinterpret_cast<char *>(access_ptr + access_offset);
328
+
329
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
330
+ }
331
+
332
+ /// Advances to the next tile in memory.
333
+ CUTLASS_HOST_DEVICE
334
+ RegularTileAccessIteratorDirectConv &operator++() {
335
+ ++iteration_contiguous_;
336
+
337
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
338
+ return *this;
339
+
340
+ // Enter here only if (iteration_contiguous_ ==
341
+ // ThreadMap::Iteration::kContiguous)
342
+ iteration_contiguous_ = 0;
343
+ ++iteration_strided_;
344
+
345
+ if (iteration_strided_ < total_iteration_strided_) {
346
+ return *this;
347
+ }
348
+
349
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
350
+ // which means we enter the next tile.
351
+ iteration_strided_ = 0;
352
+
353
+ return *this;
354
+ }
355
+
356
+ /// Advances to the next tile in memory.
357
+ CUTLASS_HOST_DEVICE
358
+ RegularTileAccessIteratorDirectConv operator++(int) {
359
+ RegularTileAccessIteratorDirectConv prev(*this);
360
+ this->operator++();
361
+
362
+ return prev;
363
+ }
364
+
365
+ /// Adds a tile offset in the unit of tile.
366
+ CUTLASS_DEVICE
367
+ void add_tile_offset(TensorCoord const &coord) {
368
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous +
369
+ coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ *
370
+ ThreadMap::kElementsPerAccess);
371
+ }
372
+ };
373
+
374
+ ////////////////////////////////////////////////////////////////////////////////
375
+
376
+ /// Tile iterator specialized for column major layouts
377
+ ///
378
+ ///
379
+ /// Satisfies: ForwardTileIteratorConcept |
380
+ /// ReadableContiguousTileIteratorConcept |
381
+ /// WriteableContiguousTileIteratorConcept
382
+ ///
383
+ template <typename Shape_, typename Element_, int AdvanceRank,
384
+ typename ThreadMap_,bool Dynamic_iterations, int Alignment >
385
+ class RegularTileAccessIteratorDirectConv<
386
+ Shape_, Element_,
387
+ layout::ColumnMajor,
388
+ AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> {
389
+ public:
390
+ static_assert(
391
+ AdvanceRank == 0 || AdvanceRank == 1,
392
+ "Specialization for pitch-linear iterator may along advance along the "
393
+ "contiguous(rank=0) or strided(rank=1) dimension.");
394
+
395
+ using Shape = Shape_;
396
+ using Element = Element_;
397
+ using Layout = layout::ColumnMajor;
398
+ static int const kAdvanceRank = AdvanceRank;
399
+ static int const kAlignment = Alignment;
400
+
401
+ using Index = typename Layout::Index;
402
+ using LongIndex = typename Layout::LongIndex;
403
+
404
+ using TensorRef = TensorRef<Element, Layout>;
405
+ using TensorCoord = typename Layout::TensorCoord;
406
+
407
+ using ThreadMap = ThreadMap_;
408
+
409
+ /// Underlying iterator type
410
+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv<
411
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
412
+ layout::PitchLinear,
413
+ (kAdvanceRank == 0 ? 0 : 1),
414
+ ThreadMap_,
415
+ Dynamic_iterations>;
416
+
417
+ using AccessType = typename UnderlyingIterator::AccessType;
418
+
419
+ private:
420
+
421
+ /// Underlying iterator
422
+ UnderlyingIterator iterator_;
423
+
424
+ public:
425
+ /// Construct a TileIterator with zero threadblock offset
426
+ CUTLASS_HOST_DEVICE
427
+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
428
+ int thread_id ///< ID of each participating thread
429
+ )
430
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
431
+
432
+ /// Overrides the internal iteration index
433
+ CUTLASS_HOST_DEVICE
434
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
435
+
436
+ /// Overrides the internal iteration index
437
+ CUTLASS_HOST_DEVICE
438
+ void set_iteration_num(int num) {
439
+ iterator_.set_iteration_num(num);
440
+ }
441
+
442
+ /// Adds a pointer offset in units of Element
443
+ CUTLASS_HOST_DEVICE
444
+ void add_pointer_offset(LongIndex pointer_offset) {
445
+ iterator_.add_pointer_offset(pointer_offset);
446
+ }
447
+
448
+ /// Returns a pointer
449
+ CUTLASS_HOST_DEVICE
450
+ AccessType *get() const {
451
+ return reinterpret_cast<AccessType *>(iterator_.get());
452
+ }
453
+
454
+ /// Adds a tile offset
455
+ CUTLASS_DEVICE
456
+ void add_tile_offset(TensorCoord const &coord) {
457
+ iterator_.add_tile_offset({coord.row(), coord.column()});
458
+ }
459
+
460
+ /// Advances to the next tile in memory.
461
+ CUTLASS_HOST_DEVICE
462
+ RegularTileAccessIteratorDirectConv &operator++() {
463
+ ++iterator_;
464
+ return *this;
465
+ }
466
+
467
+ /// Advances to the next tile in memory.
468
+ CUTLASS_HOST_DEVICE
469
+ RegularTileAccessIteratorDirectConv operator++(int) {
470
+ RegularTileAccessIteratorDirectConv prev(*this);
471
+ ++iterator_;
472
+
473
+ return prev;
474
+ }
475
+ };
476
+
477
+
478
+ ////////////////////////////////////////////////////////////////////////////////
479
+
480
+ /// Tile iterator specialized for row major layouts
481
+ ///
482
+ ///
483
+ /// Satisfies: ForwardTileIteratorConcept |
484
+ /// ReadableContiguousTileIteratorConcept |
485
+ /// WriteableContiguousTileIteratorConcept
486
+ ///
487
+ template <typename Shape_, typename Element_, int AdvanceRank,
488
+ typename ThreadMap_,bool Dynamic_iterations, int Alignment>
489
+ class RegularTileAccessIteratorDirectConv<
490
+ Shape_, Element_,
491
+ layout::RowMajor,
492
+ AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> {
493
+ public:
494
+ static_assert(
495
+ AdvanceRank == 0 || AdvanceRank == 1,
496
+ "Specialization for pitch-linear iterator may along advance along the "
497
+ "contiguous(rank=0) or strided(rank=1) dimension.");
498
+
499
+ using Shape = Shape_;
500
+ using Element = Element_;
501
+ using Layout = layout::RowMajor;
502
+ static int const kAdvanceRank = AdvanceRank;
503
+ static int const kAlignment = Alignment;
504
+
505
+ using Index = typename Layout::Index;
506
+ using LongIndex = typename Layout::LongIndex;
507
+
508
+ using TensorRef = TensorRef<Element, Layout>;
509
+ using TensorCoord = typename Layout::TensorCoord;
510
+
511
+ using ThreadMap = ThreadMap_;
512
+
513
+ /// Underlying iterator type
514
+ using UnderlyingIterator = RegularTileAccessIteratorDirectConv<
515
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
516
+ layout::PitchLinear,
517
+ (kAdvanceRank == 0 ? 1 : 0),
518
+ ThreadMap_,
519
+ Dynamic_iterations>;
520
+
521
+ using AccessType = typename UnderlyingIterator::AccessType;
522
+
523
+ private:
524
+
525
+ /// Underlying iterator
526
+ UnderlyingIterator iterator_;
527
+
528
+ public:
529
+ /// Construct a TileIterator with zero threadblock offset
530
+ CUTLASS_HOST_DEVICE
531
+ RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
532
+ int thread_id ///< ID of each participating thread
533
+ )
534
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
535
+
536
+ /// Overrides the internal iteration index
537
+ CUTLASS_HOST_DEVICE
538
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
539
+
540
+ /// Overrides the internal iteration index
541
+ CUTLASS_HOST_DEVICE
542
+ void set_iteration_num(int num) {
543
+ iterator_.set_iteration_num(num);
544
+ }
545
+
546
+ /// Adds a pointer offset in units of Element
547
+ CUTLASS_HOST_DEVICE
548
+ void add_pointer_offset(LongIndex pointer_offset) {
549
+ iterator_.add_pointer_offset(pointer_offset);
550
+ }
551
+
552
+ /// Returns a pointer
553
+ CUTLASS_HOST_DEVICE
554
+ AccessType *get() const {
555
+ return reinterpret_cast<AccessType *>(iterator_.get());
556
+ }
557
+
558
+ /// Adds a tile offset
559
+ CUTLASS_DEVICE
560
+ void add_tile_offset(TensorCoord const &coord) {
561
+ iterator_.add_tile_offset({coord.column(), coord.row()});
562
+ }
563
+
564
+ /// Advances to the next tile in memory.
565
+ CUTLASS_HOST_DEVICE
566
+ RegularTileAccessIteratorDirectConv &operator++() {
567
+ ++iterator_;
568
+ return *this;
569
+ }
570
+
571
+ /// Advances to the next tile in memory.
572
+ CUTLASS_HOST_DEVICE
573
+ RegularTileAccessIteratorDirectConv operator++(int) {
574
+ RegularTileAccessIteratorDirectConv prev(*this);
575
+ ++iterator_;
576
+
577
+ return prev;
578
+ }
579
+ };
580
+
581
+ ////////////////////////////////////////////////////////////////////////////////
582
+
583
+ } // namespace threadblock
584
+ } // namespace transform
585
+ } // namespace cutlass
586
+
587
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h ADDED
@@ -0,0 +1,821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing computing the addresses of storing of tiles
33
+ from pitch-linear rank=2 tensors.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+ #include "cutlass/layout/tensor_op_multiplicand_sm75.h"
42
+ #include "cutlass/matrix_coord.h"
43
+ #include "cutlass/matrix_shape.h"
44
+ #include "cutlass/tensor_ref.h"
45
+ #include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
46
+
47
+ ////////////////////////////////////////////////////////////////////////////////
48
+
49
+ namespace cutlass {
50
+ namespace transform {
51
+ namespace threadblock {
52
+
53
+ ////////////////////////////////////////////////////////////////////////////////
54
+
55
+ /// Tile iterator specialized for congruous arrangements for TensorOps
56
+ ///
57
+ ///
58
+ /// Satisfies: ForwardTileIteratorConcept |
59
+ /// ReadableContiguousTileIteratorConcept |
60
+ /// WriteableContiguousTileIteratorConcept
61
+ ///
62
+ template <typename Shape_, typename Element_, int AdvanceRank,
63
+ typename ThreadMap_, int Alignment, int Crosswise>
64
+ class RegularTileAccessIterator<
65
+ Shape_, Element_,
66
+ layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
67
+ Crosswise>,
68
+ AdvanceRank, ThreadMap_, Alignment> {
69
+ public:
70
+ static_assert(
71
+ AdvanceRank == 0 || AdvanceRank == 1,
72
+ "Specialization for pitch-linear iterator may along advance along the "
73
+ "contiguous(rank=0) or strided(rank=1) dimension.");
74
+
75
+ using Shape = Shape_;
76
+ using Element = Element_;
77
+ using Layout =
78
+ layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
79
+ Crosswise>;
80
+ static int const kAdvanceRank = AdvanceRank;
81
+ static int const kAlignment = Alignment;
82
+ static int const kCrosswise = Crosswise;
83
+
84
+ using Index = typename Layout::Index;
85
+ using LongIndex = typename Layout::LongIndex;
86
+ using StrideIndex = typename Layout::Stride::Index;
87
+
88
+ using TensorRef = TensorRef<Element, Layout>;
89
+ using TensorCoord = typename Layout::TensorCoord;
90
+
91
+ using ThreadMap = ThreadMap_;
92
+
93
+ /// Internal details made public to facilitate introspection
94
+ struct Detail {
95
+ /// This iterator is specialized for an access size that is 128 bits in
96
+ /// length.
97
+ static int const kAccessSizeInBits = 128;
98
+
99
+ static_assert(sizeof_bits<Element_>::value *
100
+ ThreadMap::kElementsPerAccess ==
101
+ kAccessSizeInBits,
102
+ "This iterator requires a policy whose access size is 128bs");
103
+
104
+ ///< Number of pointers
105
+ static int const kPointerCount =
106
+ (ThreadMap::Iterations::kStrided > 1 ? 2 : 1);
107
+ };
108
+
109
+ /// Element type per access
110
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
111
+
112
+ private:
113
+ //
114
+ // Data members
115
+ //
116
+
117
+ /// Stride value
118
+ StrideIndex stride_;
119
+
120
+ /// Internal pointer to first access of tile
121
+ AccessType *pointer_[Detail::kPointerCount];
122
+
123
+ /// Internal byte offset
124
+ Index byte_offset_;
125
+
126
+ /// Iteration in the contiguous dimension
127
+ int iteration_contiguous_;
128
+
129
+ /// Iteration in the strided dimension
130
+ int iteration_strided_;
131
+
132
+ public:
133
+ /// Construct a TileIterator with zero threadblock offset
134
+ CUTLASS_HOST_DEVICE
135
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
136
+ int thread_id ///< ID of each participating thread
137
+ )
138
+ : stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess),
139
+ byte_offset_(0) {
140
+ layout::PitchLinearCoord thread_offset_base =
141
+ ThreadMap::initial_offset(thread_id);
142
+
143
+ CUTLASS_PRAGMA_UNROLL
144
+ for (int i = 0; i < Detail::kPointerCount; ++i) {
145
+ // This is the offset of a thread within a threadblock tile for a specific
146
+ // pointer (units of elements)
147
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile =
148
+ thread_offset_base +
149
+ layout::PitchLinearCoord{
150
+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i};
151
+
152
+ // initialize pointer
153
+ pointer_[i] = reinterpret_cast<AccessType *>(
154
+ ref.data() + ref.offset(thread_offset_in_threadblock_tile));
155
+ }
156
+
157
+ set_iteration_index(0);
158
+ }
159
+
160
+ /// Overrides the internal iteration index
161
+ CUTLASS_HOST_DEVICE
162
+ void set_iteration_index(int index) {
163
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
164
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
165
+ }
166
+
167
+ /// Adds a pointer offset in units of Element
168
+ CUTLASS_HOST_DEVICE
169
+ void add_pointer_offset(LongIndex pointer_offset) {
170
+ byte_offset_ += pointer_offset * sizeof(Element);
171
+ }
172
+
173
+ /// Returns a pointer
174
+ CUTLASS_HOST_DEVICE
175
+ AccessType *get() const {
176
+ AccessType *access_ptr = pointer_[iteration_strided_ & 1];
177
+ int stride_idx = (iteration_strided_ & ~1);
178
+
179
+ int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor +
180
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
181
+ ThreadMap::kElementsPerAccess;
182
+
183
+ char *access_byte_ptr =
184
+ reinterpret_cast<char *>(access_ptr + access_offset);
185
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
186
+ }
187
+
188
+ /// Advances to the next tile in memory.
189
+ CUTLASS_HOST_DEVICE
190
+ RegularTileAccessIterator &operator++() {
191
+ ++iteration_contiguous_;
192
+
193
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
194
+ return *this;
195
+
196
+ // Enter here only if (iteration_contiguous_ ==
197
+ // ThreadMap::Iteration::kContiguous)
198
+ iteration_contiguous_ = 0;
199
+ ++iteration_strided_;
200
+
201
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
202
+ return *this;
203
+ }
204
+
205
+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided)
206
+ // which means we enter the next tile.
207
+ iteration_strided_ = 0;
208
+
209
+ return *this;
210
+ }
211
+
212
+ /// Advances to the next tile in memory.
213
+ CUTLASS_HOST_DEVICE
214
+ RegularTileAccessIterator operator++(int) {
215
+ RegularTileAccessIterator prev(*this);
216
+ this->operator++();
217
+
218
+ return prev;
219
+ }
220
+
221
+ /// Adds a tile offset
222
+ CUTLASS_DEVICE
223
+ void add_tile_offset(TensorCoord const &coord) {
224
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous * Layout::kFactor +
225
+ coord.strided() * Shape::kStrided * stride_ *
226
+ Layout::kElementsPerAccess / Layout::kFactor);
227
+ }
228
+ };
229
+
230
+ ////////////////////////////////////////////////////////////////////////////////
231
+
232
+ /// Tile Iterator specialized for column-major congruous TensorOp formats.
233
+ ///
234
+ ///
235
+ /// Satisfies: ForwardTileIteratorConcept |
236
+ /// ReadableContiguousTileIteratorConcept |
237
+ /// WriteableContiguousTileIteratorConcept
238
+ ///
239
+ template <typename Shape_, typename Element_, int AdvanceRank,
240
+ typename ThreadMap_, int Alignment, int Crosswise>
241
+ class RegularTileAccessIterator<
242
+ Shape_, Element_,
243
+ layout::ColumnMajorTensorOpMultiplicandCongruous<
244
+ sizeof_bits<Element_>::value, Crosswise>,
245
+ AdvanceRank, ThreadMap_, Alignment> {
246
+ public:
247
+ static_assert(
248
+ AdvanceRank == 0 || AdvanceRank == 1,
249
+ "Specialization for column-major iterator may along advance along the "
250
+ "columns(rank=0) or rows(rank=1) dimension.");
251
+
252
+ using Shape = Shape_;
253
+ using Element = Element_;
254
+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous<
255
+ sizeof_bits<Element_>::value, Crosswise>;
256
+ static int const kAdvanceRank = AdvanceRank;
257
+ static int const kAlignment = Alignment;
258
+
259
+ using Index = typename Layout::Index;
260
+ using LongIndex = typename Layout::LongIndex;
261
+
262
+ using TensorRef = TensorRef<Element, Layout>;
263
+ using TensorCoord = typename Layout::TensorCoord;
264
+
265
+ using ThreadMap = ThreadMap_;
266
+
267
+ /// Underlying iterator type
268
+ using UnderlyingIterator = RegularTileAccessIterator<
269
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
270
+ layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
271
+ Crosswise>,
272
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
273
+
274
+ using AccessType = typename UnderlyingIterator::AccessType;
275
+
276
+ private:
277
+ /// Underlying iterator
278
+ UnderlyingIterator iterator_;
279
+
280
+ public:
281
+ /// Construct a TileIterator with zero threadblock offset
282
+ CUTLASS_HOST_DEVICE
283
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
284
+ int thread_id ///< ID of each participating thread
285
+ )
286
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
287
+
288
+ /// Overrides the internal iteration index
289
+ CUTLASS_HOST_DEVICE
290
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
291
+
292
+ /// Adds a pointer offset in units of Element
293
+ CUTLASS_HOST_DEVICE
294
+ void add_pointer_offset(LongIndex pointer_offset) {
295
+ iterator_.add_pointer_offset(pointer_offset);
296
+ }
297
+
298
+ /// Returns a pointer
299
+ CUTLASS_HOST_DEVICE
300
+ AccessType *get() const {
301
+ return reinterpret_cast<AccessType *>(iterator_.get());
302
+ }
303
+
304
+ /// Adds a tile offset
305
+ CUTLASS_DEVICE
306
+ void add_tile_offset(TensorCoord const &coord) {
307
+ iterator_.add_tile_offset({coord.row(), coord.column()});
308
+ }
309
+
310
+ /// Advances to the next tile in memory.
311
+ CUTLASS_HOST_DEVICE
312
+ RegularTileAccessIterator &operator++() {
313
+ ++iterator_;
314
+ return *this;
315
+ }
316
+
317
+ /// Advances to the next tile in memory.
318
+ CUTLASS_HOST_DEVICE
319
+ RegularTileAccessIterator operator++(int) {
320
+ RegularTileAccessIterator prev(*this);
321
+ ++iterator_;
322
+
323
+ return prev;
324
+ }
325
+ };
326
+
327
+ ////////////////////////////////////////////////////////////////////////////////
328
+
329
+ /// Tile Iterator specialized for row-major congruous TensorOp formats.
330
+ ///
331
+ ///
332
+ /// Satisfies: ForwardTileIteratorConcept |
333
+ /// ReadableContiguousTileIteratorConcept |
334
+ /// WriteableContiguousTileIteratorConcept
335
+ ///
336
+ template <typename Shape_, typename Element_, int AdvanceRank,
337
+ typename ThreadMap_, int Alignment, int Crosswise>
338
+ class RegularTileAccessIterator<
339
+ Shape_, Element_,
340
+ layout::RowMajorTensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
341
+ Crosswise>,
342
+ AdvanceRank, ThreadMap_, Alignment> {
343
+ public:
344
+ static_assert(
345
+ AdvanceRank == 0 || AdvanceRank == 1,
346
+ "Specialization for row-major iterator may along advance along the "
347
+ "columns(rank=0) or rows(rank=1) dimension.");
348
+
349
+ using Shape = Shape_;
350
+ using Element = Element_;
351
+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous<
352
+ sizeof_bits<Element_>::value, Crosswise>;
353
+ static int const kAdvanceRank = AdvanceRank;
354
+ static int const kAlignment = Alignment;
355
+
356
+ using Index = typename Layout::Index;
357
+ using LongIndex = typename Layout::LongIndex;
358
+
359
+ using TensorRef = TensorRef<Element, Layout>;
360
+ using TensorCoord = typename Layout::TensorCoord;
361
+
362
+ using ThreadMap = ThreadMap_;
363
+
364
+ /// Underlying iterator type
365
+ using UnderlyingIterator = RegularTileAccessIterator<
366
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
367
+ layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
368
+ Crosswise>,
369
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
370
+
371
+ using AccessType = typename UnderlyingIterator::AccessType;
372
+
373
+ private:
374
+ /// Underlying iterator
375
+ UnderlyingIterator iterator_;
376
+
377
+ public:
378
+ /// Construct a TileIterator with zero threadblock offset
379
+ CUTLASS_HOST_DEVICE
380
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
381
+ int thread_id ///< ID of each participating thread
382
+ )
383
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
384
+
385
+ /// Overrides the internal iteration index
386
+ CUTLASS_HOST_DEVICE
387
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
388
+
389
+ /// Adds a pointer offset in units of Element
390
+ CUTLASS_HOST_DEVICE
391
+ void add_pointer_offset(LongIndex pointer_offset) {
392
+ iterator_.add_pointer_offset(pointer_offset);
393
+ }
394
+
395
+ /// Returns a pointer
396
+ CUTLASS_HOST_DEVICE
397
+ AccessType *get() const {
398
+ return reinterpret_cast<AccessType *>(iterator_.get());
399
+ }
400
+
401
+ /// Adds a tile offset
402
+ CUTLASS_DEVICE
403
+ void add_tile_offset(TensorCoord const &coord) {
404
+ iterator_.add_tile_offset({coord.column(), coord.row()});
405
+ }
406
+
407
+ /// Advances to the next tile in memory.
408
+ CUTLASS_HOST_DEVICE
409
+ RegularTileAccessIterator &operator++() {
410
+ ++iterator_;
411
+ return *this;
412
+ }
413
+
414
+ /// Advances to the next tile in memory.
415
+ CUTLASS_HOST_DEVICE
416
+ RegularTileAccessIterator operator++(int) {
417
+ RegularTileAccessIterator prev(*this);
418
+ ++iterator_;
419
+
420
+ return prev;
421
+ }
422
+ };
423
+
424
+ ////////////////////////////////////////////////////////////////////////////////
425
+
426
+ /// Tile iterator specialized for crosswise arrangements for TensorOps
427
+ ///
428
+ ///
429
+ /// Satisfies: ForwardTileIteratorConcept |
430
+ /// ReadableContiguousTileIteratorConcept |
431
+ /// WriteableContiguousTileIteratorConcept
432
+ ///
433
+ template <typename Shape_, typename Element_, int AdvanceRank,
434
+ typename ThreadMap_, int Alignment, int Crosswise>
435
+ class RegularTileAccessIterator<Shape_, Element_,
436
+ layout::TensorOpMultiplicandCrosswise<
437
+ sizeof_bits<Element_>::value, Crosswise>,
438
+ AdvanceRank, ThreadMap_, Alignment> {
439
+ public:
440
+ static_assert(
441
+ AdvanceRank == 0 || AdvanceRank == 1,
442
+ "Specialization for pitch-linear iterator may along advance along the "
443
+ "contiguous(rank=0) or strided(rank=1) dimension.");
444
+
445
+ using Shape = Shape_;
446
+ using Element = Element_;
447
+ using Layout =
448
+ layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
449
+ Crosswise>;
450
+ static int const kAdvanceRank = AdvanceRank;
451
+ static int const kAlignment = Alignment;
452
+ static int const kCrosswise = Crosswise;
453
+
454
+ using Index = typename Layout::Index;
455
+ using LongIndex = typename Layout::LongIndex;
456
+ using StrideIndex = typename Layout::Stride::Index;
457
+
458
+ using TensorRef = TensorRef<Element, Layout>;
459
+ using TensorCoord = typename Layout::TensorCoord;
460
+
461
+ using ThreadMap = ThreadMap_;
462
+
463
+ static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise),
464
+ "kCrosswise is the smallest unit in the contiguous dimension "
465
+ "for shared memory swizzling.");
466
+
467
+ /// Internal details made public to facilitate introspection
468
+ struct Detail {
469
+ /// This iterator is specialized for an access size that is 128 bits in
470
+ /// length.
471
+ static int const kAccessSizeInBits = 128;
472
+
473
+ static_assert(sizeof_bits<Element_>::value *
474
+ ThreadMap::kElementsPerAccess ==
475
+ kAccessSizeInBits,
476
+ "This iterator requires a policy whose access size is 128bs");
477
+
478
+ /// Number of pointers
479
+ ///
480
+ /// Note:TN kblock32 layouts only needs 1 pointer, but strangely
481
+ /// reducing pointer count hurts perfomrnace
482
+ static int const kPointerCount =
483
+ (ThreadMap::Iterations::kStrided > 1 ? 2 : 1);
484
+ };
485
+
486
+ /// Element type per access
487
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
488
+
489
+ private:
490
+ //
491
+ // Data members
492
+ //
493
+
494
+ /// Total number of sections. The memory is divided into stages. One stage
495
+ /// can store one tile. Stage is divided into sections. Interleaved layout
496
+ /// can have multiple sections in a stage. The rest layout only has one section
497
+ /// in a stage.
498
+ int sections_;
499
+
500
+ /// Sections that a stage has
501
+ int sections_per_stage_;
502
+
503
+ /// Stride value
504
+ StrideIndex stride_;
505
+
506
+ /// Internal pointer to first access of tile
507
+ AccessType *pointer_[Detail::kPointerCount];
508
+
509
+ /// Internal byte offset
510
+ Index byte_offset_;
511
+
512
+ /// Iteration in the contiguous dimension
513
+ int iteration_contiguous_;
514
+
515
+ /// Iteration in the strided dimension
516
+ int iteration_strided_;
517
+
518
+ public:
519
+ /// Construct a TileIterator with zero threadblock offset
520
+ CUTLASS_HOST_DEVICE
521
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
522
+ int thread_id ///< ID of each participating thread
523
+ )
524
+ : sections_(ref.stride(0) / kCrosswise),
525
+ sections_per_stage_(Shape::kContiguous / kCrosswise),
526
+ // stride_ = kCrosswise x sections_ x kFactor
527
+ stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess),
528
+ byte_offset_(0) {
529
+ layout::PitchLinearCoord thread_offset_base =
530
+ ThreadMap::initial_offset(thread_id);
531
+
532
+ CUTLASS_PRAGMA_UNROLL
533
+ for (int i = 0; i < Detail::kPointerCount; ++i) {
534
+ // This is the offset of a thread within a threadblock tile for a specific
535
+ // pointer (units of elements)
536
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile =
537
+ thread_offset_base +
538
+ layout::PitchLinearCoord{
539
+ 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i};
540
+ // initialize pointer
541
+ pointer_[i] = reinterpret_cast<AccessType *>(ref.data()) +
542
+ ref.offset(thread_offset_in_threadblock_tile) /
543
+ Layout::kElementsPerAccess;
544
+ }
545
+
546
+ set_iteration_index(0);
547
+ }
548
+
549
+ /// Overrides the internal iteration index
550
+ CUTLASS_HOST_DEVICE
551
+ void set_iteration_index(int index) {
552
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
553
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
554
+ }
555
+
556
+ /// Adds a pointer offset in units of Element
557
+ CUTLASS_HOST_DEVICE
558
+ void add_pointer_offset(LongIndex pointer_offset) {
559
+ byte_offset_ += pointer_offset * sizeof_bits<Element>::value / 8;
560
+ }
561
+
562
+ /// Returns a pointer
563
+ CUTLASS_HOST_DEVICE
564
+ AccessType *get() const {
565
+ AccessType *access_ptr = pointer_[iteration_strided_ & 1];
566
+ int stride_idx = (iteration_strided_ & ~1);
567
+
568
+ int access_offset =
569
+ stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor +
570
+ // kCrosswise elements in the contiguous dimension would span to a
571
+ // shared memory cache line.
572
+ iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) *
573
+ Layout::TileShape::kContiguous;
574
+ char *access_byte_ptr =
575
+ reinterpret_cast<char *>(access_ptr + access_offset);
576
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
577
+ }
578
+
579
+ /// Advances to the next tile in memory.
580
+ CUTLASS_HOST_DEVICE
581
+ RegularTileAccessIterator &operator++() {
582
+ ++iteration_contiguous_;
583
+
584
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
585
+ return *this;
586
+
587
+ // Enter here only if (iteration_contiguous_ ==
588
+ // ThreadMap::Iteration::kContiguous)
589
+ iteration_contiguous_ = 0;
590
+ ++iteration_strided_;
591
+
592
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
593
+ return *this;
594
+ }
595
+
596
+ // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided)
597
+ // which means we enter the next section.
598
+ iteration_strided_ = 0;
599
+
600
+ return *this;
601
+ }
602
+
603
+ /// Advances to the next tile in memory.
604
+ CUTLASS_HOST_DEVICE
605
+ RegularTileAccessIterator operator++(int) {
606
+ RegularTileAccessIterator prev(*this);
607
+ this->operator++();
608
+
609
+ return prev;
610
+ }
611
+
612
+ /// Adds a tile offset
613
+ CUTLASS_DEVICE
614
+ void add_tile_offset(TensorCoord const &coord) {
615
+ add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ *
616
+ ThreadMap::kElementsPerAccess / sections_ +
617
+ coord.strided() * Shape::kStrided * stride_ *
618
+ Layout::kElementsPerAccess / Layout::kFactor);
619
+ }
620
+ };
621
+
622
+ ////////////////////////////////////////////////////////////////////////////////
623
+
624
+ /// Tile Iterator specialized for column-major crosswise TensorOp formats.
625
+ ///
626
+ ///
627
+ /// Satisfies: ForwardTileIteratorConcept |
628
+ /// ReadableContiguousTileIteratorConcept |
629
+ /// WriteableContiguousTileIteratorConcept
630
+ ///
631
+ template <typename Shape_, typename Element_, int AdvanceRank,
632
+ typename ThreadMap_, int Alignment, int Crosswise>
633
+ class RegularTileAccessIterator<
634
+ Shape_, Element_,
635
+ layout::ColumnMajorTensorOpMultiplicandCrosswise<
636
+ sizeof_bits<Element_>::value, Crosswise>,
637
+ AdvanceRank, ThreadMap_, Alignment> {
638
+ public:
639
+ static_assert(
640
+ AdvanceRank == 0 || AdvanceRank == 1,
641
+ "Specialization for column-major iterator may along advance along the "
642
+ "columns(rank=0) or rows(rank=1) dimension.");
643
+
644
+ using Shape = Shape_;
645
+ using Element = Element_;
646
+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise<
647
+ sizeof_bits<Element_>::value, Crosswise>;
648
+ static int const kAdvanceRank = AdvanceRank;
649
+ static int const kAlignment = Alignment;
650
+
651
+ using Index = typename Layout::Index;
652
+ using LongIndex = typename Layout::LongIndex;
653
+
654
+ using TensorRef = TensorRef<Element, Layout>;
655
+ using TensorCoord = typename Layout::TensorCoord;
656
+
657
+ using ThreadMap = ThreadMap_;
658
+
659
+ /// Underlying iterator type
660
+ using UnderlyingIterator = RegularTileAccessIterator<
661
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
662
+ layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
663
+ Crosswise>,
664
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
665
+
666
+ using AccessType = typename UnderlyingIterator::AccessType;
667
+
668
+ private:
669
+ /// Underlying iterator
670
+ UnderlyingIterator iterator_;
671
+
672
+ public:
673
+ /// Construct a TileIterator with zero threadblock offset
674
+ CUTLASS_HOST_DEVICE
675
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
676
+ int thread_id ///< ID of each participating thread
677
+ )
678
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
679
+
680
+ /// Overrides the internal iteration index
681
+ CUTLASS_HOST_DEVICE
682
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
683
+
684
+ /// Adds a pointer offset in units of Element
685
+ CUTLASS_HOST_DEVICE
686
+ void add_pointer_offset(LongIndex pointer_offset) {
687
+ iterator_.add_pointer_offset(pointer_offset);
688
+ }
689
+
690
+ /// Returns a pointer
691
+ CUTLASS_HOST_DEVICE
692
+ AccessType *get() const {
693
+ return reinterpret_cast<AccessType *>(iterator_.get());
694
+ }
695
+
696
+ /// Adds a tile offset
697
+ CUTLASS_DEVICE
698
+ void add_tile_offset(TensorCoord const &coord) {
699
+ iterator_.add_tile_offset({coord.row(), coord.column()});
700
+ }
701
+
702
+ /// Advances to the next tile in memory.
703
+ CUTLASS_HOST_DEVICE
704
+ RegularTileAccessIterator &operator++() {
705
+ ++iterator_;
706
+ return *this;
707
+ }
708
+
709
+ /// Advances to the next tile in memory.
710
+ CUTLASS_HOST_DEVICE
711
+ RegularTileAccessIterator operator++(int) {
712
+ RegularTileAccessIterator prev(*this);
713
+ ++iterator_;
714
+
715
+ return prev;
716
+ }
717
+ };
718
+
719
+ ////////////////////////////////////////////////////////////////////////////////
720
+
721
+ /// Tile Iterator specialized for row-major crosswise TensorOp formats.
722
+ ///
723
+ ///
724
+ /// Satisfies: ForwardTileIteratorConcept |
725
+ /// ReadableContiguousTileIteratorConcept |
726
+ /// WriteableContiguousTileIteratorConcept
727
+ ///
728
+ template <typename Shape_, typename Element_, int AdvanceRank,
729
+ typename ThreadMap_, int Alignment, int Crosswise>
730
+ class RegularTileAccessIterator<Shape_, Element_,
731
+ layout::RowMajorTensorOpMultiplicandCrosswise<
732
+ sizeof_bits<Element_>::value, Crosswise>,
733
+ AdvanceRank, ThreadMap_, Alignment> {
734
+ public:
735
+ static_assert(
736
+ AdvanceRank == 0 || AdvanceRank == 1,
737
+ "Specialization for row-major iterator may along advance along the "
738
+ "columns(rank=0) or rows(rank=1) dimension.");
739
+
740
+ using Shape = Shape_;
741
+ using Element = Element_;
742
+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise<
743
+ sizeof_bits<Element_>::value, Crosswise>;
744
+ static int const kAdvanceRank = AdvanceRank;
745
+ static int const kAlignment = Alignment;
746
+
747
+ using Index = typename Layout::Index;
748
+ using LongIndex = typename Layout::LongIndex;
749
+
750
+ using TensorRef = TensorRef<Element, Layout>;
751
+ using TensorCoord = typename Layout::TensorCoord;
752
+
753
+ using ThreadMap = ThreadMap_;
754
+
755
+ /// Underlying iterator type
756
+ using UnderlyingIterator = RegularTileAccessIterator<
757
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
758
+ layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
759
+ Crosswise>,
760
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
761
+
762
+ using AccessType = typename UnderlyingIterator::AccessType;
763
+
764
+ private:
765
+ /// Underlying iterator
766
+ UnderlyingIterator iterator_;
767
+
768
+ public:
769
+ /// Construct a TileIterator with zero threadblock offset
770
+ CUTLASS_HOST_DEVICE
771
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
772
+ int thread_id ///< ID of each participating thread
773
+ )
774
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
775
+
776
+ /// Overrides the internal iteration index
777
+ CUTLASS_HOST_DEVICE
778
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
779
+
780
+ /// Adds a pointer offset in units of Element
781
+ CUTLASS_HOST_DEVICE
782
+ void add_pointer_offset(LongIndex pointer_offset) {
783
+ iterator_.add_pointer_offset(pointer_offset);
784
+ }
785
+
786
+ /// Returns a pointer
787
+ CUTLASS_HOST_DEVICE
788
+ AccessType *get() const {
789
+ return reinterpret_cast<AccessType *>(iterator_.get());
790
+ }
791
+
792
+ /// Adds a tile offset
793
+ CUTLASS_DEVICE
794
+ void add_tile_offset(TensorCoord const &coord) {
795
+ iterator_.add_tile_offset({coord.column(), coord.row()});
796
+ }
797
+
798
+ /// Advances to the next tile in memory.
799
+ CUTLASS_HOST_DEVICE
800
+ RegularTileAccessIterator &operator++() {
801
+ ++iterator_;
802
+ return *this;
803
+ }
804
+
805
+ /// Advances to the next tile in memory.
806
+ CUTLASS_HOST_DEVICE
807
+ RegularTileAccessIterator operator++(int) {
808
+ RegularTileAccessIterator prev(*this);
809
+ ++iterator_;
810
+
811
+ return prev;
812
+ }
813
+ };
814
+
815
+ ////////////////////////////////////////////////////////////////////////////////
816
+
817
+ } // namespace threadblock
818
+ } // namespace transform
819
+ } // namespace cutlass
820
+
821
+ ////////////////////////////////////////////////////////////////////////////////
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h ADDED
@@ -0,0 +1,1532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Templates implementing computing the addresses of storing of tiles
33
+ from pitch-linear rank=2 tensors.
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include "cutlass/array.h"
39
+ #include "cutlass/cutlass.h"
40
+ #include "cutlass/layout/pitch_linear.h"
41
+ #include "cutlass/layout/tensor_op_multiplicand_sm75.h"
42
+ #include "cutlass/layout/tensor_op_multiplicand_sm80.h"
43
+ #include "cutlass/matrix_coord.h"
44
+ #include "cutlass/matrix_shape.h"
45
+ #include "cutlass/tensor_ref.h"
46
+ #include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
47
+
48
+ ////////////////////////////////////////////////////////////////////////////////
49
+
50
+ namespace cutlass {
51
+ namespace transform {
52
+ namespace threadblock {
53
+
54
+ ////////////////////////////////////////////////////////////////////////////////
55
+
56
+ /// Tile iterator specialized for congruous arrangements for TensorOps
57
+ ///
58
+ ///
59
+ /// Satisfies: ForwardTileIteratorConcept |
60
+ /// ReadableContiguousTileIteratorConcept |
61
+ /// WriteableContiguousTileIteratorConcept
62
+ ///
63
+ template <typename Shape_, typename Element_, int AdvanceRank,
64
+ typename ThreadMap_, int Alignment>
65
+ class RegularTileAccessIterator<
66
+ Shape_, Element_,
67
+ layout::TensorOpMultiplicandCongruous64b,
68
+ AdvanceRank, ThreadMap_, Alignment> {
69
+ public:
70
+ static_assert(
71
+ AdvanceRank == 0 || AdvanceRank == 1,
72
+ "Specialization for pitch-linear iterator may along advance along the "
73
+ "contiguous(rank=0) or strided(rank=1) dimension.");
74
+
75
+ using Shape = Shape_;
76
+ using Element = Element_;
77
+ using Layout = layout::TensorOpMultiplicandCongruous64b;
78
+ static int const kAdvanceRank = AdvanceRank;
79
+ static int const kAlignment = Alignment;
80
+
81
+ using Index = typename Layout::Index;
82
+ using LongIndex = typename Layout::LongIndex;
83
+ using StrideIndex = typename Layout::Stride::Index;
84
+
85
+ using TensorRef = TensorRef<Element, Layout>;
86
+ using TensorCoord = typename Layout::TensorCoord;
87
+
88
+ using ThreadMap = ThreadMap_;
89
+
90
+ static_assert(ThreadMap::kThreads / 32 > 1,
91
+ "This tile iterator requires at least two warps.");
92
+
93
+ /// Internal details made public to facilitate introspection
94
+ struct Detail {
95
+ /// This iterator is specialized for an access size that is 128 bits in
96
+ /// length.
97
+ static int const kAccessSizeInBits = 64;
98
+
99
+ static_assert(sizeof_bits<Element_>::value *
100
+ ThreadMap::kElementsPerAccess ==
101
+ kAccessSizeInBits,
102
+ "This iterator requires a policy whose access size is 64b");
103
+
104
+ ///< Number of pointers
105
+ static int const kPointerCount = 1;
106
+ };
107
+
108
+ /// Element type per access
109
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
110
+
111
+ private:
112
+ //
113
+ // Data members
114
+ //
115
+
116
+ /// Stride value
117
+ StrideIndex stride_;
118
+
119
+ /// Internal pointer to first access of tile
120
+ AccessType *pointer_;
121
+
122
+ /// Internal byte offset
123
+ Index byte_offset_;
124
+
125
+ /// Iteration in the contiguous dimension
126
+ int iteration_contiguous_;
127
+
128
+ /// Iteration in the strided dimension
129
+ int iteration_strided_;
130
+
131
+ public:
132
+
133
+ /// Construct a TileIterator with zero threadblock offset
134
+ CUTLASS_HOST_DEVICE
135
+ RegularTileAccessIterator(
136
+ TensorRef ref, ///< Pointer to start of tensor
137
+ int thread_id ///< ID of each participating thread
138
+ ):
139
+ stride_(ref.stride(0) / Layout::kElementsPerAccess),
140
+ byte_offset_(0) {
141
+
142
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
143
+
144
+ // This is the offset of a thread within a threadblock tile for a specific
145
+ // pointer (units of elements)
146
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
147
+
148
+ // initialize pointer
149
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
150
+
151
+ set_iteration_index(0);
152
+ }
153
+
154
+ /// Overrides the internal iteration index
155
+ CUTLASS_HOST_DEVICE
156
+ void set_iteration_index(int index) {
157
+
158
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
159
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
160
+ }
161
+
162
+ /// Adds a pointer offset in units of Element
163
+ CUTLASS_HOST_DEVICE
164
+ void add_pointer_offset(LongIndex pointer_offset) {
165
+
166
+ byte_offset_ += pointer_offset * sizeof(Element);
167
+ }
168
+
169
+ /// Returns a pointer
170
+ CUTLASS_HOST_DEVICE
171
+ AccessType *get() const {
172
+
173
+ AccessType *access_ptr = pointer_;
174
+
175
+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
176
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
177
+ ThreadMap::kElementsPerAccess;
178
+
179
+ char *access_byte_ptr =
180
+ reinterpret_cast<char *>(access_ptr + access_offset);
181
+
182
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
183
+ }
184
+
185
+ /// Advances to the next tile in memory.
186
+ CUTLASS_HOST_DEVICE
187
+ RegularTileAccessIterator &operator++() {
188
+ ++iteration_contiguous_;
189
+
190
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
191
+ return *this;
192
+
193
+ // Enter here only if (iteration_contiguous_ ==
194
+ // ThreadMap::Iteration::kContiguous)
195
+ iteration_contiguous_ = 0;
196
+ ++iteration_strided_;
197
+
198
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
199
+ return *this;
200
+ }
201
+
202
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
203
+ // which means we enter the next tile.
204
+ iteration_strided_ = 0;
205
+
206
+ return *this;
207
+ }
208
+
209
+ /// Advances to the next tile in memory.
210
+ CUTLASS_HOST_DEVICE
211
+ RegularTileAccessIterator operator++(int) {
212
+
213
+ RegularTileAccessIterator prev(*this);
214
+
215
+ this->operator++();
216
+
217
+ return prev;
218
+ }
219
+
220
+ /// Adds a tile offset
221
+ CUTLASS_DEVICE
222
+ void add_tile_offset(TensorCoord const &coord) {
223
+
224
+ add_pointer_offset(
225
+ coord.contiguous() * Shape::kContiguous +
226
+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess);
227
+ }
228
+ };
229
+
230
+ ////////////////////////////////////////////////////////////////////////////////
231
+
232
+ /// Tile Iterator specialized for column-major congruous TensorOp formats.
233
+ ///
234
+ ///
235
+ /// Satisfies: ForwardTileIteratorConcept |
236
+ /// ReadableContiguousTileIteratorConcept |
237
+ /// WriteableContiguousTileIteratorConcept
238
+ ///
239
+ template <typename Shape_, typename Element_, int AdvanceRank,
240
+ typename ThreadMap_, int Alignment>
241
+ class RegularTileAccessIterator<
242
+ Shape_, Element_,
243
+ layout::ColumnMajorTensorOpMultiplicandCongruous64b,
244
+ AdvanceRank, ThreadMap_, Alignment> {
245
+ public:
246
+ static_assert(
247
+ AdvanceRank == 0 || AdvanceRank == 1,
248
+ "Specialization for column-major iterator may along advance along the "
249
+ "columns(rank=0) or rows(rank=1) dimension.");
250
+
251
+ using Shape = Shape_;
252
+ using Element = Element_;
253
+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b;
254
+ static int const kAdvanceRank = AdvanceRank;
255
+ static int const kAlignment = Alignment;
256
+
257
+ using Index = typename Layout::Index;
258
+ using LongIndex = typename Layout::LongIndex;
259
+
260
+ using TensorRef = TensorRef<Element, Layout>;
261
+ using TensorCoord = typename Layout::TensorCoord;
262
+
263
+ using ThreadMap = ThreadMap_;
264
+
265
+ /// Underlying iterator type
266
+ using UnderlyingIterator = RegularTileAccessIterator<
267
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
268
+ layout::TensorOpMultiplicandCongruous64b,
269
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
270
+
271
+ using AccessType = typename UnderlyingIterator::AccessType;
272
+
273
+ private:
274
+ /// Underlying iterator
275
+ UnderlyingIterator iterator_;
276
+
277
+ public:
278
+ /// Construct a TileIterator with zero threadblock offset
279
+ CUTLASS_HOST_DEVICE
280
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
281
+ int thread_id ///< ID of each participating thread
282
+ )
283
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
284
+
285
+ /// Overrides the internal iteration index
286
+ CUTLASS_HOST_DEVICE
287
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
288
+
289
+ /// Adds a pointer offset in units of Element
290
+ CUTLASS_HOST_DEVICE
291
+ void add_pointer_offset(LongIndex pointer_offset) {
292
+ iterator_.add_pointer_offset(pointer_offset);
293
+ }
294
+
295
+ /// Returns a pointer
296
+ CUTLASS_HOST_DEVICE
297
+ AccessType *get() const {
298
+ return reinterpret_cast<AccessType *>(iterator_.get());
299
+ }
300
+
301
+ /// Adds a tile offset
302
+ CUTLASS_DEVICE
303
+ void add_tile_offset(TensorCoord const &coord) {
304
+ iterator_.add_tile_offset({coord.row(), coord.column()});
305
+ }
306
+
307
+ /// Advances to the next tile in memory.
308
+ CUTLASS_HOST_DEVICE
309
+ RegularTileAccessIterator &operator++() {
310
+ ++iterator_;
311
+ return *this;
312
+ }
313
+
314
+ /// Advances to the next tile in memory.
315
+ CUTLASS_HOST_DEVICE
316
+ RegularTileAccessIterator operator++(int) {
317
+ RegularTileAccessIterator prev(*this);
318
+ ++iterator_;
319
+
320
+ return prev;
321
+ }
322
+ };
323
+
324
+ ////////////////////////////////////////////////////////////////////////////////
325
+
326
+ /// Tile Iterator specialized for row-major congruous TensorOp formats.
327
+ ///
328
+ ///
329
+ /// Satisfies: ForwardTileIteratorConcept |
330
+ /// ReadableContiguousTileIteratorConcept |
331
+ /// WriteableContiguousTileIteratorConcept
332
+ ///
333
+ template <typename Shape_, typename Element_, int AdvanceRank,
334
+ typename ThreadMap_, int Alignment>
335
+ class RegularTileAccessIterator<Shape_, Element_,
336
+ layout::RowMajorTensorOpMultiplicandCongruous64b,
337
+ AdvanceRank, ThreadMap_, Alignment> {
338
+ public:
339
+ static_assert(
340
+ AdvanceRank == 0 || AdvanceRank == 1,
341
+ "Specialization for row-major iterator may along advance along the "
342
+ "columns(rank=0) or rows(rank=1) dimension.");
343
+
344
+ using Shape = Shape_;
345
+ using Element = Element_;
346
+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b;
347
+ static int const kAdvanceRank = AdvanceRank;
348
+ static int const kAlignment = Alignment;
349
+
350
+ using Index = typename Layout::Index;
351
+ using LongIndex = typename Layout::LongIndex;
352
+
353
+ using TensorRef = TensorRef<Element, Layout>;
354
+ using TensorCoord = typename Layout::TensorCoord;
355
+
356
+ using ThreadMap = ThreadMap_;
357
+
358
+ /// Underlying iterator type
359
+ using UnderlyingIterator = RegularTileAccessIterator<
360
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
361
+ layout::TensorOpMultiplicandCongruous64b,
362
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
363
+
364
+ using AccessType = typename UnderlyingIterator::AccessType;
365
+
366
+ private:
367
+ /// Underlying iterator
368
+ UnderlyingIterator iterator_;
369
+
370
+ public:
371
+ /// Construct a TileIterator with zero threadblock offset
372
+ CUTLASS_HOST_DEVICE
373
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
374
+ int thread_id ///< ID of each participating thread
375
+ )
376
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
377
+
378
+ /// Overrides the internal iteration index
379
+ CUTLASS_HOST_DEVICE
380
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
381
+
382
+ /// Adds a pointer offset in units of Element
383
+ CUTLASS_HOST_DEVICE
384
+ void add_pointer_offset(LongIndex pointer_offset) {
385
+ iterator_.add_pointer_offset(pointer_offset);
386
+ }
387
+
388
+ /// Returns a pointer
389
+ CUTLASS_HOST_DEVICE
390
+ AccessType *get() const {
391
+ return reinterpret_cast<AccessType *>(iterator_.get());
392
+ }
393
+
394
+ /// Adds a tile offset
395
+ CUTLASS_DEVICE
396
+ void add_tile_offset(TensorCoord const &coord) {
397
+ iterator_.add_tile_offset({coord.column(), coord.row()});
398
+ }
399
+
400
+ /// Advances to the next tile in memory.
401
+ CUTLASS_HOST_DEVICE
402
+ RegularTileAccessIterator &operator++() {
403
+ ++iterator_;
404
+ return *this;
405
+ }
406
+
407
+ /// Advances to the next tile in memory.
408
+ CUTLASS_HOST_DEVICE
409
+ RegularTileAccessIterator operator++(int) {
410
+ RegularTileAccessIterator prev(*this);
411
+ ++iterator_;
412
+
413
+ return prev;
414
+ }
415
+ };
416
+
417
+ ////////////////////////////////////////////////////////////////////////////////
418
+ ////////////////////////////////////////////////////////////////////////////////
419
+
420
+ /// Tile iterator specialized for crosswise arrangements for TensorOps
421
+ ///
422
+ ///
423
+ /// Satisfies: ForwardTileIteratorConcept |
424
+ /// ReadableContiguousTileIteratorConcept |
425
+ /// WriteableContiguousTileIteratorConcept
426
+ ///
427
+ template <typename Shape_, typename Element_, int AdvanceRank,
428
+ typename ThreadMap_, int Alignment>
429
+ class RegularTileAccessIterator<
430
+ Shape_, Element_,
431
+ layout::TensorOpMultiplicand64bCrosswise,
432
+ AdvanceRank, ThreadMap_, Alignment> {
433
+ public:
434
+ static_assert(
435
+ AdvanceRank == 0 || AdvanceRank == 1,
436
+ "Specialization for pitch-linear iterator may along advance along the "
437
+ "contiguous(rank=0) or strided(rank=1) dimension.");
438
+
439
+ using Shape = Shape_;
440
+ using Element = Element_;
441
+ using Layout = layout::TensorOpMultiplicand64bCrosswise;
442
+ static int const kAdvanceRank = AdvanceRank;
443
+ static int const kAlignment = Alignment;
444
+
445
+ using Index = typename Layout::Index;
446
+ using LongIndex = typename Layout::LongIndex;
447
+ using StrideIndex = typename Layout::Stride::Index;
448
+
449
+ using TensorRef = TensorRef<Element, Layout>;
450
+ using TensorCoord = typename Layout::TensorCoord;
451
+
452
+ using ThreadMap = ThreadMap_;
453
+
454
+ static_assert(ThreadMap::kThreads / 32 > 1,
455
+ "This tile iterator requires at least two warps.");
456
+
457
+ /// Internal details made public to facilitate introspection
458
+ struct Detail {
459
+ /// This iterator is specialized for an access size that is 128 bits in
460
+ /// length.
461
+ static int const kAccessSizeInBits = 64;
462
+
463
+ static_assert(sizeof_bits<Element_>::value *
464
+ ThreadMap::kElementsPerAccess ==
465
+ kAccessSizeInBits,
466
+ "This iterator requires a policy whose access size is 64b");
467
+
468
+ ///< Number of pointers - two pointers are needed if making more than 4 iterations along
469
+ ///< strided dimension
470
+ static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1);
471
+ };
472
+
473
+ /// Element type per access
474
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
475
+
476
+ private:
477
+ //
478
+ // Data members
479
+ //
480
+
481
+ /// Stride value
482
+ StrideIndex stride_;
483
+
484
+ /// Internal pointer to first access of tile
485
+ AccessType *pointer_;
486
+
487
+ /// Internal byte offset
488
+ Index byte_offset_[Detail::kPointerCount];
489
+
490
+ /// Iteration in the contiguous dimension
491
+ int iteration_contiguous_;
492
+
493
+ /// Iteration in the strided dimension
494
+ int iteration_strided_;
495
+
496
+ public:
497
+
498
+ /// Construct a TileIterator with zero threadblock offset
499
+ CUTLASS_DEVICE
500
+ RegularTileAccessIterator(
501
+ TensorRef ref, ///< Pointer to start of tensor
502
+ int thread_id ///< ID of each participating thread
503
+ ):
504
+ stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) {
505
+
506
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
507
+
508
+ // This is the offset of a thread within a threadblock tile for a specific
509
+ // pointer (units of elements)
510
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
511
+
512
+ // initialize pointer
513
+ pointer_ = reinterpret_cast<AccessType *>(ref.data());
514
+
515
+ byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element);
516
+
517
+ if (Detail::kPointerCount == 2) {
518
+ byte_offset_[1] = byte_offset_[0] ^ 8;
519
+ }
520
+
521
+ set_iteration_index(0);
522
+ }
523
+
524
+ /// Overrides the internal iteration index
525
+ CUTLASS_HOST_DEVICE
526
+ void set_iteration_index(int index) {
527
+
528
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
529
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
530
+ }
531
+
532
+ /// Adds a pointer offset in units of Element
533
+ CUTLASS_HOST_DEVICE
534
+ void add_pointer_offset(LongIndex pointer_offset) {
535
+
536
+ pointer_ += pointer_offset / ThreadMap::kElementsPerAccess;
537
+ }
538
+
539
+ /// Returns a pointer
540
+ CUTLASS_DEVICE
541
+ AccessType *get() const {
542
+
543
+ // Map the logical contiguous and strided access to the internal swizzled structure.
544
+ int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_;
545
+
546
+ char *access_byte_ptr = reinterpret_cast<char *>(pointer_ + uniform_offset);
547
+
548
+ int byte_offset;
549
+
550
+ // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations)
551
+ // in the strided dimension
552
+ if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) {
553
+ byte_offset = byte_offset_[1];
554
+ }
555
+ else {
556
+ byte_offset = byte_offset_[0];
557
+ }
558
+
559
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset);
560
+ }
561
+
562
+ /// Advances to the next tile in memory.
563
+ CUTLASS_HOST_DEVICE
564
+ RegularTileAccessIterator &operator++() {
565
+ ++iteration_contiguous_;
566
+
567
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
568
+ return *this;
569
+
570
+ // Enter here only if (iteration_contiguous_ ==
571
+ // ThreadMap::Iteration::kContiguous)
572
+ iteration_contiguous_ = 0;
573
+ ++iteration_strided_;
574
+
575
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
576
+ return *this;
577
+ }
578
+
579
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
580
+ // which means we enter the next tile.
581
+ iteration_strided_ = 0;
582
+
583
+ return *this;
584
+ }
585
+
586
+ /// Advances to the next tile in memory.
587
+ CUTLASS_HOST_DEVICE
588
+ RegularTileAccessIterator operator++(int) {
589
+
590
+ RegularTileAccessIterator prev(*this);
591
+
592
+ this->operator++();
593
+
594
+ return prev;
595
+ }
596
+
597
+ /// Adds a tile offset
598
+ CUTLASS_DEVICE
599
+ void add_tile_offset(TensorCoord const &coord) {
600
+
601
+ add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_);
602
+ }
603
+ };
604
+
605
+ ////////////////////////////////////////////////////////////////////////////////
606
+
607
+ /// Tile Iterator specialized for column-major crosswise TensorOp formats.
608
+ ///
609
+ ///
610
+ /// Satisfies: ForwardTileIteratorConcept |
611
+ /// ReadableContiguousTileIteratorConcept |
612
+ /// WriteableContiguousTileIteratorConcept
613
+ ///
614
+ template <typename Shape_, typename Element_, int AdvanceRank,
615
+ typename ThreadMap_, int Alignment>
616
+ class RegularTileAccessIterator<
617
+ Shape_, Element_,
618
+ layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
619
+ AdvanceRank, ThreadMap_, Alignment> {
620
+ public:
621
+ static_assert(
622
+ AdvanceRank == 0 || AdvanceRank == 1,
623
+ "Specialization for column-major iterator may along advance along the "
624
+ "columns(rank=0) or rows(rank=1) dimension.");
625
+
626
+ using Shape = Shape_;
627
+ using Element = Element_;
628
+ using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise;
629
+ static int const kAdvanceRank = AdvanceRank;
630
+ static int const kAlignment = Alignment;
631
+
632
+ using Index = typename Layout::Index;
633
+ using LongIndex = typename Layout::LongIndex;
634
+
635
+ using TensorRef = TensorRef<Element, Layout>;
636
+ using TensorCoord = typename Layout::TensorCoord;
637
+
638
+ using ThreadMap = ThreadMap_;
639
+
640
+ /// Underlying iterator type
641
+ using UnderlyingIterator = RegularTileAccessIterator<
642
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
643
+ layout::TensorOpMultiplicand64bCrosswise,
644
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
645
+
646
+ using AccessType = typename UnderlyingIterator::AccessType;
647
+
648
+ private:
649
+ /// Underlying iterator
650
+ UnderlyingIterator iterator_;
651
+
652
+ public:
653
+ /// Construct a TileIterator with zero threadblock offset
654
+ CUTLASS_HOST_DEVICE
655
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
656
+ int thread_id ///< ID of each participating thread
657
+ )
658
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
659
+
660
+ /// Overrides the internal iteration index
661
+ CUTLASS_HOST_DEVICE
662
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
663
+
664
+ /// Adds a pointer offset in units of Element
665
+ CUTLASS_HOST_DEVICE
666
+ void add_pointer_offset(LongIndex pointer_offset) {
667
+ iterator_.add_pointer_offset(pointer_offset);
668
+ }
669
+
670
+ /// Returns a pointer
671
+ CUTLASS_HOST_DEVICE
672
+ AccessType *get() const {
673
+ return reinterpret_cast<AccessType *>(iterator_.get());
674
+ }
675
+
676
+ /// Adds a tile offset
677
+ CUTLASS_DEVICE
678
+ void add_tile_offset(TensorCoord const &coord) {
679
+ iterator_.add_tile_offset({coord.row(), coord.column()});
680
+ }
681
+
682
+ /// Advances to the next tile in memory.
683
+ CUTLASS_HOST_DEVICE
684
+ RegularTileAccessIterator &operator++() {
685
+ ++iterator_;
686
+ return *this;
687
+ }
688
+
689
+ /// Advances to the next tile in memory.
690
+ CUTLASS_HOST_DEVICE
691
+ RegularTileAccessIterator operator++(int) {
692
+ RegularTileAccessIterator prev(*this);
693
+ ++iterator_;
694
+
695
+ return prev;
696
+ }
697
+ };
698
+
699
+ ////////////////////////////////////////////////////////////////////////////////
700
+
701
+ /// Tile Iterator specialized for row-major crosswise TensorOp formats.
702
+ ///
703
+ ///
704
+ /// Satisfies: ForwardTileIteratorConcept |
705
+ /// ReadableContiguousTileIteratorConcept |
706
+ /// WriteableContiguousTileIteratorConcept
707
+ ///
708
+ template <typename Shape_, typename Element_, int AdvanceRank,
709
+ typename ThreadMap_, int Alignment>
710
+ class RegularTileAccessIterator<Shape_, Element_,
711
+ layout::RowMajorTensorOpMultiplicand64bCrosswise,
712
+ AdvanceRank, ThreadMap_, Alignment> {
713
+ public:
714
+ static_assert(
715
+ AdvanceRank == 0 || AdvanceRank == 1,
716
+ "Specialization for row-major iterator may along advance along the "
717
+ "columns(rank=0) or rows(rank=1) dimension.");
718
+
719
+ using Shape = Shape_;
720
+ using Element = Element_;
721
+ using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise;
722
+ static int const kAdvanceRank = AdvanceRank;
723
+ static int const kAlignment = Alignment;
724
+
725
+ using Index = typename Layout::Index;
726
+ using LongIndex = typename Layout::LongIndex;
727
+
728
+ using TensorRef = TensorRef<Element, Layout>;
729
+ using TensorCoord = typename Layout::TensorCoord;
730
+
731
+ using ThreadMap = ThreadMap_;
732
+
733
+ /// Underlying iterator type
734
+ using UnderlyingIterator = RegularTileAccessIterator<
735
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
736
+ layout::TensorOpMultiplicand64bCrosswise,
737
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
738
+
739
+ using AccessType = typename UnderlyingIterator::AccessType;
740
+
741
+ private:
742
+ /// Underlying iterator
743
+ UnderlyingIterator iterator_;
744
+
745
+ public:
746
+ /// Construct a TileIterator with zero threadblock offset
747
+ CUTLASS_HOST_DEVICE
748
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
749
+ int thread_id ///< ID of each participating thread
750
+ )
751
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
752
+
753
+ /// Overrides the internal iteration index
754
+ CUTLASS_HOST_DEVICE
755
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
756
+
757
+ /// Adds a pointer offset in units of Element
758
+ CUTLASS_HOST_DEVICE
759
+ void add_pointer_offset(LongIndex pointer_offset) {
760
+ iterator_.add_pointer_offset(pointer_offset);
761
+ }
762
+
763
+ /// Returns a pointer
764
+ CUTLASS_HOST_DEVICE
765
+ AccessType *get() const {
766
+ return reinterpret_cast<AccessType *>(iterator_.get());
767
+ }
768
+
769
+ /// Adds a tile offset
770
+ CUTLASS_DEVICE
771
+ void add_tile_offset(TensorCoord const &coord) {
772
+ iterator_.add_tile_offset({coord.column(), coord.row()});
773
+ }
774
+
775
+ /// Advances to the next tile in memory.
776
+ CUTLASS_HOST_DEVICE
777
+ RegularTileAccessIterator &operator++() {
778
+ ++iterator_;
779
+ return *this;
780
+ }
781
+
782
+ /// Advances to the next tile in memory.
783
+ CUTLASS_HOST_DEVICE
784
+ RegularTileAccessIterator operator++(int) {
785
+ RegularTileAccessIterator prev(*this);
786
+ ++iterator_;
787
+
788
+ return prev;
789
+ }
790
+ };
791
+
792
+ /////////////////////////////////////////////////////////////////////////////////////////////////
793
+ /////////////////////////////////////////////////////////////////////////////////////////////////
794
+
795
+ /// Tile iterator specialized for congruous arrangements for TensorOps
796
+ ///
797
+ ///
798
+ /// Satisfies: ForwardTileIteratorConcept |
799
+ /// ReadableContiguousTileIteratorConcept |
800
+ /// WriteableContiguousTileIteratorConcept
801
+ ///
802
+ template <typename Shape_, typename Element_, int AdvanceRank,
803
+ typename ThreadMap_, int Alignment>
804
+ class RegularTileAccessIterator<
805
+ Shape_, Element_,
806
+ layout::TensorOpMultiplicandCongruous128b,
807
+ AdvanceRank, ThreadMap_, Alignment> {
808
+ public:
809
+ static_assert(
810
+ AdvanceRank == 0 || AdvanceRank == 1,
811
+ "Specialization for pitch-linear iterator may along advance along the "
812
+ "contiguous(rank=0) or strided(rank=1) dimension.");
813
+
814
+ using Shape = Shape_;
815
+ using Element = Element_;
816
+ using Layout = layout::TensorOpMultiplicandCongruous128b;
817
+ static int const kAdvanceRank = AdvanceRank;
818
+ static int const kAlignment = Alignment;
819
+
820
+ using Index = typename Layout::Index;
821
+ using LongIndex = typename Layout::LongIndex;
822
+ using StrideIndex = typename Layout::Stride::Index;
823
+
824
+ using TensorRef = TensorRef<Element, Layout>;
825
+ using TensorCoord = typename Layout::TensorCoord;
826
+
827
+ using ThreadMap = ThreadMap_;
828
+
829
+ static_assert(ThreadMap::kThreads / 32 > 1,
830
+ "This tile iterator requires at least two warps.");
831
+
832
+ /// Internal details made public to facilitate introspection
833
+ struct Detail {
834
+ /// This iterator is specialized for an access size that is 128 bits in
835
+ /// length.
836
+ static int const kAccessSizeInBits = 128;
837
+
838
+ static_assert(sizeof_bits<Element_>::value *
839
+ ThreadMap::kElementsPerAccess ==
840
+ kAccessSizeInBits,
841
+ "This iterator requires a policy whose access size is 128b");
842
+
843
+ ///< Number of pointers
844
+ static int const kPointerCount = 1;
845
+ };
846
+
847
+ /// Element type per access
848
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
849
+
850
+ private:
851
+ //
852
+ // Data members
853
+ //
854
+
855
+ /// Stride value
856
+ StrideIndex stride_;
857
+
858
+ /// Internal pointer to first access of tile
859
+ AccessType *pointer_;
860
+
861
+ /// Internal byte offset
862
+ Index byte_offset_;
863
+
864
+ /// Iteration in the contiguous dimension
865
+ int iteration_contiguous_;
866
+
867
+ /// Iteration in the strided dimension
868
+ int iteration_strided_;
869
+
870
+ public:
871
+
872
+ /// Construct a TileIterator with zero threadblock offset
873
+ CUTLASS_HOST_DEVICE
874
+ RegularTileAccessIterator(
875
+ TensorRef ref, ///< Pointer to start of tensor
876
+ int thread_id ///< ID of each participating thread
877
+ ):
878
+ stride_(ref.stride(0) / Layout::kElementsPerAccess),
879
+ byte_offset_(0) {
880
+
881
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
882
+
883
+ // This is the offset of a thread within a threadblock tile for a specific
884
+ // pointer (units of elements)
885
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
886
+
887
+ // initialize pointer
888
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
889
+
890
+ set_iteration_index(0);
891
+ }
892
+
893
+ /// Overrides the internal iteration index
894
+ CUTLASS_HOST_DEVICE
895
+ void set_iteration_index(int index) {
896
+
897
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
898
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
899
+ }
900
+
901
+ /// Adds a pointer offset in units of Element
902
+ CUTLASS_HOST_DEVICE
903
+ void add_pointer_offset(LongIndex pointer_offset) {
904
+
905
+ byte_offset_ += pointer_offset * sizeof(Element);
906
+ }
907
+
908
+ /// Returns a pointer
909
+ CUTLASS_HOST_DEVICE
910
+ AccessType *get() const {
911
+
912
+ AccessType *access_ptr = pointer_;
913
+
914
+ int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
915
+ iteration_contiguous_ * ThreadMap::Delta::kContiguous /
916
+ ThreadMap::kElementsPerAccess;
917
+
918
+ char *access_byte_ptr =
919
+ reinterpret_cast<char *>(access_ptr + access_offset);
920
+
921
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
922
+ }
923
+
924
+ /// Advances to the next tile in memory.
925
+ CUTLASS_HOST_DEVICE
926
+ RegularTileAccessIterator &operator++() {
927
+ ++iteration_contiguous_;
928
+
929
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
930
+ return *this;
931
+
932
+ // Enter here only if (iteration_contiguous_ ==
933
+ // ThreadMap::Iteration::kContiguous)
934
+ iteration_contiguous_ = 0;
935
+ ++iteration_strided_;
936
+
937
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
938
+ return *this;
939
+ }
940
+
941
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
942
+ // which means we enter the next tile.
943
+ iteration_strided_ = 0;
944
+
945
+ return *this;
946
+ }
947
+
948
+ /// Advances to the next tile in memory.
949
+ CUTLASS_HOST_DEVICE
950
+ RegularTileAccessIterator operator++(int) {
951
+
952
+ RegularTileAccessIterator prev(*this);
953
+
954
+ this->operator++();
955
+
956
+ return prev;
957
+ }
958
+
959
+ /// Adds a tile offset
960
+ CUTLASS_DEVICE
961
+ void add_tile_offset(TensorCoord const &coord) {
962
+
963
+ add_pointer_offset(
964
+ coord.contiguous() * Shape::kContiguous +
965
+ coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess);
966
+ }
967
+ };
968
+
969
+ ////////////////////////////////////////////////////////////////////////////////
970
+
971
+ /// Tile Iterator specialized for column-major congruous TensorOp formats.
972
+ ///
973
+ ///
974
+ /// Satisfies: ForwardTileIteratorConcept |
975
+ /// ReadableContiguousTileIteratorConcept |
976
+ /// WriteableContiguousTileIteratorConcept
977
+ ///
978
+ template <typename Shape_, typename Element_, int AdvanceRank,
979
+ typename ThreadMap_, int Alignment>
980
+ class RegularTileAccessIterator<
981
+ Shape_, Element_,
982
+ layout::ColumnMajorTensorOpMultiplicandCongruous128b,
983
+ AdvanceRank, ThreadMap_, Alignment> {
984
+ public:
985
+ static_assert(
986
+ AdvanceRank == 0 || AdvanceRank == 1,
987
+ "Specialization for column-major iterator may along advance along the "
988
+ "columns(rank=0) or rows(rank=1) dimension.");
989
+
990
+ using Shape = Shape_;
991
+ using Element = Element_;
992
+ using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b;
993
+ static int const kAdvanceRank = AdvanceRank;
994
+ static int const kAlignment = Alignment;
995
+
996
+ using Index = typename Layout::Index;
997
+ using LongIndex = typename Layout::LongIndex;
998
+
999
+ using TensorRef = TensorRef<Element, Layout>;
1000
+ using TensorCoord = typename Layout::TensorCoord;
1001
+
1002
+ using ThreadMap = ThreadMap_;
1003
+
1004
+ /// Underlying iterator type
1005
+ using UnderlyingIterator = RegularTileAccessIterator<
1006
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
1007
+ layout::TensorOpMultiplicandCongruous128b,
1008
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
1009
+
1010
+ using AccessType = typename UnderlyingIterator::AccessType;
1011
+
1012
+ private:
1013
+ /// Underlying iterator
1014
+ UnderlyingIterator iterator_;
1015
+
1016
+ public:
1017
+ /// Construct a TileIterator with zero threadblock offset
1018
+ CUTLASS_HOST_DEVICE
1019
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
1020
+ int thread_id ///< ID of each participating thread
1021
+ )
1022
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
1023
+
1024
+ /// Overrides the internal iteration index
1025
+ CUTLASS_HOST_DEVICE
1026
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1027
+
1028
+ /// Adds a pointer offset in units of Element
1029
+ CUTLASS_HOST_DEVICE
1030
+ void add_pointer_offset(LongIndex pointer_offset) {
1031
+ iterator_.add_pointer_offset(pointer_offset);
1032
+ }
1033
+
1034
+ /// Returns a pointer
1035
+ CUTLASS_HOST_DEVICE
1036
+ AccessType *get() const {
1037
+ return reinterpret_cast<AccessType *>(iterator_.get());
1038
+ }
1039
+
1040
+ /// Adds a tile offset
1041
+ CUTLASS_DEVICE
1042
+ void add_tile_offset(TensorCoord const &coord) {
1043
+ iterator_.add_tile_offset({coord.row(), coord.column()});
1044
+ }
1045
+
1046
+ /// Advances to the next tile in memory.
1047
+ CUTLASS_HOST_DEVICE
1048
+ RegularTileAccessIterator &operator++() {
1049
+ ++iterator_;
1050
+ return *this;
1051
+ }
1052
+
1053
+ /// Advances to the next tile in memory.
1054
+ CUTLASS_HOST_DEVICE
1055
+ RegularTileAccessIterator operator++(int) {
1056
+ RegularTileAccessIterator prev(*this);
1057
+ ++iterator_;
1058
+
1059
+ return prev;
1060
+ }
1061
+ };
1062
+
1063
+ ////////////////////////////////////////////////////////////////////////////////
1064
+
1065
+ /// Tile Iterator specialized for row-major congruous TensorOp formats.
1066
+ ///
1067
+ ///
1068
+ /// Satisfies: ForwardTileIteratorConcept |
1069
+ /// ReadableContiguousTileIteratorConcept |
1070
+ /// WriteableContiguousTileIteratorConcept
1071
+ ///
1072
+ template <typename Shape_, typename Element_, int AdvanceRank,
1073
+ typename ThreadMap_, int Alignment>
1074
+ class RegularTileAccessIterator<Shape_, Element_,
1075
+ layout::RowMajorTensorOpMultiplicandCongruous128b,
1076
+ AdvanceRank, ThreadMap_, Alignment> {
1077
+ public:
1078
+ static_assert(
1079
+ AdvanceRank == 0 || AdvanceRank == 1,
1080
+ "Specialization for row-major iterator may along advance along the "
1081
+ "columns(rank=0) or rows(rank=1) dimension.");
1082
+
1083
+ using Shape = Shape_;
1084
+ using Element = Element_;
1085
+ using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b;
1086
+ static int const kAdvanceRank = AdvanceRank;
1087
+ static int const kAlignment = Alignment;
1088
+
1089
+ using Index = typename Layout::Index;
1090
+ using LongIndex = typename Layout::LongIndex;
1091
+
1092
+ using TensorRef = TensorRef<Element, Layout>;
1093
+ using TensorCoord = typename Layout::TensorCoord;
1094
+
1095
+ using ThreadMap = ThreadMap_;
1096
+
1097
+ /// Underlying iterator type
1098
+ using UnderlyingIterator = RegularTileAccessIterator<
1099
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
1100
+ layout::TensorOpMultiplicandCongruous128b,
1101
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
1102
+
1103
+ using AccessType = typename UnderlyingIterator::AccessType;
1104
+
1105
+ private:
1106
+ /// Underlying iterator
1107
+ UnderlyingIterator iterator_;
1108
+
1109
+ public:
1110
+ /// Construct a TileIterator with zero threadblock offset
1111
+ CUTLASS_HOST_DEVICE
1112
+ RegularTileAccessIterator(
1113
+ TensorRef ref, ///< Pointer to start of tensor
1114
+ int thread_id ///< ID of each participating thread
1115
+ ):
1116
+ iterator_({ref.data(), ref.stride()}, thread_id) {}
1117
+
1118
+ /// Overrides the internal iteration index
1119
+ CUTLASS_HOST_DEVICE
1120
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1121
+
1122
+ /// Adds a pointer offset in units of Element
1123
+ CUTLASS_HOST_DEVICE
1124
+ void add_pointer_offset(LongIndex pointer_offset) {
1125
+ iterator_.add_pointer_offset(pointer_offset);
1126
+ }
1127
+
1128
+ /// Returns a pointer
1129
+ CUTLASS_HOST_DEVICE
1130
+ AccessType *get() const {
1131
+ return reinterpret_cast<AccessType *>(iterator_.get());
1132
+ }
1133
+
1134
+ /// Adds a tile offset
1135
+ CUTLASS_DEVICE
1136
+ void add_tile_offset(TensorCoord const &coord) {
1137
+ iterator_.add_tile_offset({coord.column(), coord.row()});
1138
+ }
1139
+
1140
+ /// Advances to the next tile in memory.
1141
+ CUTLASS_HOST_DEVICE
1142
+ RegularTileAccessIterator &operator++() {
1143
+ ++iterator_;
1144
+ return *this;
1145
+ }
1146
+
1147
+ /// Advances to the next tile in memory.
1148
+ CUTLASS_HOST_DEVICE
1149
+ RegularTileAccessIterator operator++(int) {
1150
+ RegularTileAccessIterator prev(*this);
1151
+ ++iterator_;
1152
+
1153
+ return prev;
1154
+ }
1155
+ };
1156
+
1157
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1158
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1159
+
1160
+ /// Tile iterator specialized for congruous arrangements for TensorOps
1161
+ ///
1162
+ ///
1163
+ /// Satisfies: ForwardTileIteratorConcept |
1164
+ /// ReadableContiguousTileIteratorConcept |
1165
+ /// WriteableContiguousTileIteratorConcept
1166
+ ///
1167
+ template <typename Shape_, typename Element_, int AdvanceRank,
1168
+ typename ThreadMap_, int Alignment>
1169
+ class RegularTileAccessIterator<
1170
+ Shape_, Element_,
1171
+ layout::TensorOpMultiplicandCrosswise128x4,
1172
+ AdvanceRank, ThreadMap_, Alignment> {
1173
+ public:
1174
+ static_assert(
1175
+ AdvanceRank == 0 || AdvanceRank == 1,
1176
+ "Specialization for pitch-linear iterator may along advance along the "
1177
+ "contiguous(rank=0) or strided(rank=1) dimension.");
1178
+
1179
+ using Shape = Shape_;
1180
+ using Element = Element_;
1181
+ using Layout = layout::TensorOpMultiplicandCrosswise128x4;
1182
+ static int const kAdvanceRank = AdvanceRank;
1183
+ static int const kAlignment = Alignment;
1184
+
1185
+ using Index = typename Layout::Index;
1186
+ using LongIndex = typename Layout::LongIndex;
1187
+ using StrideIndex = typename Layout::Stride::Index;
1188
+
1189
+ using TensorRef = TensorRef<Element, Layout>;
1190
+ using TensorCoord = typename Layout::TensorCoord;
1191
+
1192
+ using ThreadMap = ThreadMap_;
1193
+
1194
+ static_assert(ThreadMap::kThreads / 32 > 1,
1195
+ "This tile iterator requires at least two warps.");
1196
+
1197
+ /// Internal details made public to facilitate introspection
1198
+ struct Detail {
1199
+ /// This iterator is specialized for an access size that is 128 bits in
1200
+ /// length.
1201
+ static int const kAccessSizeInBits = 128;
1202
+
1203
+ static_assert(sizeof_bits<Element_>::value *
1204
+ ThreadMap::kElementsPerAccess ==
1205
+ kAccessSizeInBits,
1206
+ "This iterator requires a policy whose access size is 128b");
1207
+
1208
+ ///< Number of pointers
1209
+ static int const kPointerCount = 1;
1210
+ };
1211
+
1212
+
1213
+ static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension");
1214
+
1215
+ /// Element type per access
1216
+ using AccessType = Array<Element, Layout::kElementsPerAccess>;
1217
+
1218
+ private:
1219
+ //
1220
+ // Data members
1221
+ //
1222
+
1223
+ /// Stride value
1224
+ StrideIndex stride_;
1225
+
1226
+ /// Internal pointer to first access of tile
1227
+ AccessType *pointer_;
1228
+
1229
+ /// Internal byte offset
1230
+ Index byte_offset_;
1231
+
1232
+ /// Iteration in the contiguous dimension
1233
+ int iteration_contiguous_;
1234
+
1235
+ /// Iteration in the strided dimension
1236
+ int iteration_strided_;
1237
+
1238
+ public:
1239
+
1240
+ /// Construct a TileIterator with zero threadblock offset
1241
+ CUTLASS_DEVICE
1242
+ RegularTileAccessIterator(
1243
+ TensorRef ref, ///< Pointer to start of tensor
1244
+ int thread_id ///< ID of each participating thread
1245
+ ):
1246
+ stride_(ref.stride(0) / Layout::kElementsPerAccess),
1247
+ byte_offset_(0) {
1248
+
1249
+ layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
1250
+
1251
+ // This is the offset of a thread within a threadblock tile for a specific
1252
+ // pointer (units of elements)
1253
+ layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
1254
+
1255
+ // initialize pointer
1256
+ pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
1257
+
1258
+ set_iteration_index(0);
1259
+ }
1260
+
1261
+ /// Overrides the internal iteration index
1262
+ CUTLASS_HOST_DEVICE
1263
+ void set_iteration_index(int index) {
1264
+
1265
+ iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
1266
+ iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
1267
+ }
1268
+
1269
+ /// Adds a pointer offset in units of Element
1270
+ CUTLASS_HOST_DEVICE
1271
+ void add_pointer_offset(LongIndex pointer_offset) {
1272
+
1273
+ byte_offset_ += pointer_offset * sizeof(Element);
1274
+ }
1275
+
1276
+ /// Returns a pointer
1277
+ CUTLASS_HOST_DEVICE
1278
+ AccessType *get() const {
1279
+
1280
+ AccessType *access_ptr = pointer_;
1281
+
1282
+ int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2);
1283
+ int offset_s = (iteration_strided_ / 2) * 8;
1284
+
1285
+ int access_offset = offset_c * stride_ + offset_s;
1286
+
1287
+ char *access_byte_ptr =
1288
+ reinterpret_cast<char *>(access_ptr + access_offset);
1289
+
1290
+ return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
1291
+ }
1292
+
1293
+ /// Advances to the next tile in memory.
1294
+ CUTLASS_HOST_DEVICE
1295
+ RegularTileAccessIterator &operator++() {
1296
+ ++iteration_contiguous_;
1297
+
1298
+ if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
1299
+ return *this;
1300
+
1301
+ // Enter here only if (iteration_contiguous_ ==
1302
+ // ThreadMap::Iteration::kContiguous)
1303
+ iteration_contiguous_ = 0;
1304
+ ++iteration_strided_;
1305
+
1306
+ if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
1307
+ return *this;
1308
+ }
1309
+
1310
+ // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
1311
+ // which means we enter the next tile.
1312
+ iteration_strided_ = 0;
1313
+
1314
+ return *this;
1315
+ }
1316
+
1317
+ /// Advances to the next tile in memory.
1318
+ CUTLASS_HOST_DEVICE
1319
+ RegularTileAccessIterator operator++(int) {
1320
+
1321
+ RegularTileAccessIterator prev(*this);
1322
+
1323
+ this->operator++();
1324
+
1325
+ return prev;
1326
+ }
1327
+
1328
+ /// Adds a tile offset
1329
+ CUTLASS_DEVICE
1330
+ void add_tile_offset(TensorCoord const &coord) {
1331
+
1332
+ add_pointer_offset(
1333
+ coord.contiguous() * Shape::kContiguous * stride_ +
1334
+ coord.strided() * Shape::kStrided * Layout::kElementsPerAccess);
1335
+ }
1336
+ };
1337
+
1338
+ ////////////////////////////////////////////////////////////////////////////////
1339
+
1340
+ /// Tile Iterator specialized for column-major congruous TensorOp formats.
1341
+ ///
1342
+ ///
1343
+ /// Satisfies: ForwardTileIteratorConcept |
1344
+ /// ReadableContiguousTileIteratorConcept |
1345
+ /// WriteableContiguousTileIteratorConcept
1346
+ ///
1347
+ template <typename Shape_, typename Element_, int AdvanceRank,
1348
+ typename ThreadMap_, int Alignment>
1349
+ class RegularTileAccessIterator<
1350
+ Shape_, Element_,
1351
+ layout::ColumnMajorTensorOpMultiplicandCrosswise128x4,
1352
+ AdvanceRank, ThreadMap_, Alignment> {
1353
+ public:
1354
+ static_assert(
1355
+ AdvanceRank == 0 || AdvanceRank == 1,
1356
+ "Specialization for column-major iterator may along advance along the "
1357
+ "columns(rank=0) or rows(rank=1) dimension.");
1358
+
1359
+ using Shape = Shape_;
1360
+ using Element = Element_;
1361
+ using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4;
1362
+ static int const kAdvanceRank = AdvanceRank;
1363
+ static int const kAlignment = Alignment;
1364
+
1365
+ using Index = typename Layout::Index;
1366
+ using LongIndex = typename Layout::LongIndex;
1367
+
1368
+ using TensorRef = TensorRef<Element, Layout>;
1369
+ using TensorCoord = typename Layout::TensorCoord;
1370
+
1371
+ using ThreadMap = ThreadMap_;
1372
+
1373
+ /// Underlying iterator type
1374
+ using UnderlyingIterator = RegularTileAccessIterator<
1375
+ layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
1376
+ layout::TensorOpMultiplicandCrosswise128x4,
1377
+ (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
1378
+
1379
+ using AccessType = typename UnderlyingIterator::AccessType;
1380
+
1381
+ private:
1382
+ /// Underlying iterator
1383
+ UnderlyingIterator iterator_;
1384
+
1385
+ public:
1386
+ /// Construct a TileIterator with zero threadblock offset
1387
+ CUTLASS_HOST_DEVICE
1388
+ RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
1389
+ int thread_id ///< ID of each participating thread
1390
+ )
1391
+ : iterator_({ref.data(), ref.stride()}, thread_id) {}
1392
+
1393
+ /// Overrides the internal iteration index
1394
+ CUTLASS_HOST_DEVICE
1395
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1396
+
1397
+ /// Adds a pointer offset in units of Element
1398
+ CUTLASS_HOST_DEVICE
1399
+ void add_pointer_offset(LongIndex pointer_offset) {
1400
+ iterator_.add_pointer_offset(pointer_offset);
1401
+ }
1402
+
1403
+ /// Returns a pointer
1404
+ CUTLASS_HOST_DEVICE
1405
+ AccessType *get() const {
1406
+ return reinterpret_cast<AccessType *>(iterator_.get());
1407
+ }
1408
+
1409
+ /// Adds a tile offset
1410
+ CUTLASS_DEVICE
1411
+ void add_tile_offset(TensorCoord const &coord) {
1412
+ iterator_.add_tile_offset({coord.row(), coord.column()});
1413
+ }
1414
+
1415
+ /// Advances to the next tile in memory.
1416
+ CUTLASS_HOST_DEVICE
1417
+ RegularTileAccessIterator &operator++() {
1418
+ ++iterator_;
1419
+ return *this;
1420
+ }
1421
+
1422
+ /// Advances to the next tile in memory.
1423
+ CUTLASS_HOST_DEVICE
1424
+ RegularTileAccessIterator operator++(int) {
1425
+ RegularTileAccessIterator prev(*this);
1426
+ ++iterator_;
1427
+
1428
+ return prev;
1429
+ }
1430
+ };
1431
+
1432
+ ////////////////////////////////////////////////////////////////////////////////
1433
+
1434
+ /// Tile Iterator specialized for row-major congruous TensorOp formats.
1435
+ ///
1436
+ ///
1437
+ /// Satisfies: ForwardTileIteratorConcept |
1438
+ /// ReadableContiguousTileIteratorConcept |
1439
+ /// WriteableContiguousTileIteratorConcept
1440
+ ///
1441
+ template <typename Shape_, typename Element_, int AdvanceRank,
1442
+ typename ThreadMap_, int Alignment>
1443
+ class RegularTileAccessIterator<Shape_, Element_,
1444
+ layout::RowMajorTensorOpMultiplicandCrosswise128x4,
1445
+ AdvanceRank, ThreadMap_, Alignment> {
1446
+ public:
1447
+ static_assert(
1448
+ AdvanceRank == 0 || AdvanceRank == 1,
1449
+ "Specialization for row-major iterator may along advance along the "
1450
+ "columns(rank=0) or rows(rank=1) dimension.");
1451
+
1452
+ using Shape = Shape_;
1453
+ using Element = Element_;
1454
+ using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4;
1455
+ static int const kAdvanceRank = AdvanceRank;
1456
+ static int const kAlignment = Alignment;
1457
+
1458
+ using Index = typename Layout::Index;
1459
+ using LongIndex = typename Layout::LongIndex;
1460
+
1461
+ using TensorRef = TensorRef<Element, Layout>;
1462
+ using TensorCoord = typename Layout::TensorCoord;
1463
+
1464
+ using ThreadMap = ThreadMap_;
1465
+
1466
+ /// Underlying iterator type
1467
+ using UnderlyingIterator = RegularTileAccessIterator<
1468
+ layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
1469
+ layout::TensorOpMultiplicandCrosswise128x4,
1470
+ (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
1471
+
1472
+ using AccessType = typename UnderlyingIterator::AccessType;
1473
+
1474
+ private:
1475
+ /// Underlying iterator
1476
+ UnderlyingIterator iterator_;
1477
+
1478
+ public:
1479
+ /// Construct a TileIterator with zero threadblock offset
1480
+ CUTLASS_HOST_DEVICE
1481
+ RegularTileAccessIterator(
1482
+ TensorRef ref, ///< Pointer to start of tensor
1483
+ int thread_id ///< ID of each participating thread
1484
+ ):
1485
+ iterator_({ref.data(), ref.stride()}, thread_id) {}
1486
+
1487
+ /// Overrides the internal iteration index
1488
+ CUTLASS_HOST_DEVICE
1489
+ void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1490
+
1491
+ /// Adds a pointer offset in units of Element
1492
+ CUTLASS_HOST_DEVICE
1493
+ void add_pointer_offset(LongIndex pointer_offset) {
1494
+ iterator_.add_pointer_offset(pointer_offset);
1495
+ }
1496
+
1497
+ /// Returns a pointer
1498
+ CUTLASS_HOST_DEVICE
1499
+ AccessType *get() const {
1500
+ return reinterpret_cast<AccessType *>(iterator_.get());
1501
+ }
1502
+
1503
+ /// Adds a tile offset
1504
+ CUTLASS_DEVICE
1505
+ void add_tile_offset(TensorCoord const &coord) {
1506
+ iterator_.add_tile_offset({coord.column(), coord.row()});
1507
+ }
1508
+
1509
+ /// Advances to the next tile in memory.
1510
+ CUTLASS_HOST_DEVICE
1511
+ RegularTileAccessIterator &operator++() {
1512
+ ++iterator_;
1513
+ return *this;
1514
+ }
1515
+
1516
+ /// Advances to the next tile in memory.
1517
+ CUTLASS_HOST_DEVICE
1518
+ RegularTileAccessIterator operator++(int) {
1519
+ RegularTileAccessIterator prev(*this);
1520
+ ++iterator_;
1521
+
1522
+ return prev;
1523
+ }
1524
+ };
1525
+
1526
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1527
+
1528
+ } // namespace threadblock
1529
+ } // namespace transform
1530
+ } // namespace cutlass
1531
+
1532
+ /////////////////////////////////////////////////////////////////////////////////////////////////