Kernels:
Trusted publisher
Build uploaded using `kernels` (batch 9/10).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h +232 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h +264 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h +374 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h +362 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h +267 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h +248 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h +606 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h +641 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h +234 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h +235 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h +67 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h +305 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h +118 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h +1388 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h +326 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h +419 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h +374 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h +297 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h +302 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h +479 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h +198 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h +59 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +754 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp +303 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp +223 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp +603 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp +325 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h +926 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h +107 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h +105 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h +199 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h +1350 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h +1315 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h +375 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h +328 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +2118 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h +834 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +290 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h +892 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h +1887 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h +787 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h +818 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h +417 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h +253 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h +58 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h +408 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h +587 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h +821 -0
- build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h +1532 -0
.gitattributes
CHANGED
|
@@ -23,3 +23,4 @@ build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=
|
|
| 23 |
build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 24 |
build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 25 |
build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 23 |
build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 24 |
build/torch29-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 25 |
build/torch29-cxx11-cu129-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
build/torch29-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over densely packed tensors in global memory
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/device_kernel.h"
|
| 38 |
+
#include "cutlass/reduction/kernel/reduce_split_k.h"
|
| 39 |
+
#include "cutlass/cuda_host_adapter.hpp"
|
| 40 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
namespace reduction {
|
| 44 |
+
namespace device {
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
template <
|
| 49 |
+
typename ReductionKernel_
|
| 50 |
+
>
|
| 51 |
+
class ReduceSplitK {
|
| 52 |
+
public:
|
| 53 |
+
using ReductionKernel = ReductionKernel_;
|
| 54 |
+
|
| 55 |
+
using Shape = typename ReductionKernel::Shape;
|
| 56 |
+
using ReductionOp = typename ReductionKernel::ReductionOp;
|
| 57 |
+
using OutputOp = typename ReductionKernel::OutputOp;
|
| 58 |
+
|
| 59 |
+
using ElementWorkspace = typename ReductionKernel::ElementWorkspace;
|
| 60 |
+
using ElementAccumulator = typename ReductionKernel::ElementAccumulator;
|
| 61 |
+
using ElementOutput = typename ReductionKernel::ElementOutput;
|
| 62 |
+
|
| 63 |
+
using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef;
|
| 64 |
+
using OutputTensorRef = typename ReductionKernel::OutputTensorRef;
|
| 65 |
+
|
| 66 |
+
using StrideIndex = typename ReductionKernel::StrideIndex;
|
| 67 |
+
|
| 68 |
+
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
|
| 69 |
+
|
| 70 |
+
/// Argument structure
|
| 71 |
+
struct Arguments {
|
| 72 |
+
|
| 73 |
+
//
|
| 74 |
+
// Data members
|
| 75 |
+
//
|
| 76 |
+
|
| 77 |
+
MatrixCoord problem_size{0,0};
|
| 78 |
+
int partitions{1};
|
| 79 |
+
size_t partition_stride{0};
|
| 80 |
+
WorkspaceTensorRef workspace{};
|
| 81 |
+
OutputTensorRef destination{};
|
| 82 |
+
OutputTensorRef source{};
|
| 83 |
+
typename OutputOp::Params output{};
|
| 84 |
+
typename ReductionOp::Params reduction{};
|
| 85 |
+
|
| 86 |
+
//
|
| 87 |
+
// Methods
|
| 88 |
+
//
|
| 89 |
+
|
| 90 |
+
/// Default ctor
|
| 91 |
+
Arguments() = default;
|
| 92 |
+
|
| 93 |
+
CUTLASS_HOST_DEVICE
|
| 94 |
+
Arguments(
|
| 95 |
+
MatrixCoord const & problem_size
|
| 96 |
+
):
|
| 97 |
+
problem_size(problem_size) { }
|
| 98 |
+
|
| 99 |
+
CUTLASS_HOST_DEVICE
|
| 100 |
+
Arguments(
|
| 101 |
+
MatrixCoord problem_size_,
|
| 102 |
+
int partitions_,
|
| 103 |
+
size_t partition_stride_,
|
| 104 |
+
WorkspaceTensorRef workspace_,
|
| 105 |
+
OutputTensorRef destination_,
|
| 106 |
+
OutputTensorRef source_,
|
| 107 |
+
typename OutputOp::Params output_ = typename OutputOp::Params(),
|
| 108 |
+
typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
|
| 109 |
+
):
|
| 110 |
+
problem_size(problem_size_),
|
| 111 |
+
partitions(partitions_),
|
| 112 |
+
partition_stride(partition_stride_),
|
| 113 |
+
workspace(workspace_),
|
| 114 |
+
destination(destination_),
|
| 115 |
+
source(source_),
|
| 116 |
+
output(output_),
|
| 117 |
+
reduction(reduction_)
|
| 118 |
+
{
|
| 119 |
+
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
};
|
| 123 |
+
|
| 124 |
+
private:
|
| 125 |
+
/// Kernel parameters object
|
| 126 |
+
typename ReductionKernel::Params params_;
|
| 127 |
+
|
| 128 |
+
public:
|
| 129 |
+
/// Constructs Reduction SplitK
|
| 130 |
+
ReduceSplitK() { }
|
| 131 |
+
|
| 132 |
+
/// Determines whether the ReduceSplitK can execute the given problem.
|
| 133 |
+
static Status can_implement(Arguments const &args) {
|
| 134 |
+
|
| 135 |
+
return Status::kSuccess;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
/// Gets the workspace size
|
| 139 |
+
static size_t get_workspace_size(Arguments const &args) {
|
| 140 |
+
// needs no additional workspace
|
| 141 |
+
return 0;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
/// Initializes Reduction state from arguments.
|
| 145 |
+
Status initialize(
|
| 146 |
+
Arguments const &args,
|
| 147 |
+
void *workspace = nullptr,
|
| 148 |
+
cudaStream_t stream = nullptr) {
|
| 149 |
+
|
| 150 |
+
// initialize the params structure from the arguments
|
| 151 |
+
params_ = typename ReductionKernel::Params(
|
| 152 |
+
args.problem_size,
|
| 153 |
+
args.partitions,
|
| 154 |
+
args.partition_stride,
|
| 155 |
+
args.workspace,
|
| 156 |
+
args.destination,
|
| 157 |
+
args.source,
|
| 158 |
+
args.output,
|
| 159 |
+
args.reduction
|
| 160 |
+
);
|
| 161 |
+
|
| 162 |
+
return Status::kSuccess;
|
| 163 |
+
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Initializes Reduction kernel state from arguments.
|
| 167 |
+
Status update(Arguments const &args, void *workspace = nullptr) {
|
| 168 |
+
|
| 169 |
+
// update the params structure from the arguments
|
| 170 |
+
params_.workspace.reset(args.workspace.non_const_ref().data());
|
| 171 |
+
params_.destination.reset(args.destination.non_const_ref().data());
|
| 172 |
+
params_.source.reset(args.source.non_const_ref().data());
|
| 173 |
+
params_.output = args.output;
|
| 174 |
+
params_.reduction = args.reduction;
|
| 175 |
+
|
| 176 |
+
return Status::kSuccess;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
/// Runs the kernel using initialized state.
|
| 180 |
+
Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 181 |
+
|
| 182 |
+
//
|
| 183 |
+
// Launch reduction kernel
|
| 184 |
+
//
|
| 185 |
+
dim3 block = ReductionKernel::block_shape();
|
| 186 |
+
dim3 grid = ReductionKernel::grid_shape(params_.problem_size);
|
| 187 |
+
|
| 188 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 189 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 190 |
+
if (cuda_adapter) {
|
| 191 |
+
void* kernel_params[] = {¶ms_};
|
| 192 |
+
cuda_adapter->launch(
|
| 193 |
+
grid, dim3(1,1,1), block, 0, stream, kernel_params, kernel_index);
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
else {
|
| 197 |
+
cutlass::arch::synclog_setup();
|
| 198 |
+
Kernel<ReductionKernel><<< grid, block, 0, stream >>>(params_);
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
cudaError_t result = cudaGetLastError();
|
| 202 |
+
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
/// Runs the kernel using initialized state.
|
| 207 |
+
Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 208 |
+
return run(stream, cuda_adapter, kernel_index);
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/// Runs the kernel using initialized state.
|
| 212 |
+
Status operator()(
|
| 213 |
+
Arguments const &args,
|
| 214 |
+
void *workspace = nullptr,
|
| 215 |
+
cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) {
|
| 216 |
+
|
| 217 |
+
Status status = initialize(args, workspace, stream);
|
| 218 |
+
|
| 219 |
+
if (status == Status::kSuccess) {
|
| 220 |
+
status = run(stream,cuda_adapter, kernel_index);
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
return status;
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
};
|
| 227 |
+
|
| 228 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 229 |
+
|
| 230 |
+
} // namespace kernel
|
| 231 |
+
} // namespace reduction
|
| 232 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over one or more ranks of an affine tensor
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/device_kernel.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/reduction/device/tensor_reduce_affine_strided.h"
|
| 45 |
+
#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h"
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace reduction {
|
| 51 |
+
namespace device {
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
/// Tensor reduction operator on specific CUTLASS layouts over exactly one index
|
| 56 |
+
template <
|
| 57 |
+
typename ElementOutput_,
|
| 58 |
+
typename ElementSource_,
|
| 59 |
+
typename Layout_,
|
| 60 |
+
typename ReductionOp_,
|
| 61 |
+
int VectorLength_ = 1,
|
| 62 |
+
typename ElementCompute_ = ElementOutput_
|
| 63 |
+
>
|
| 64 |
+
struct TensorReduction {
|
| 65 |
+
|
| 66 |
+
using ElementOutput = ElementOutput_;
|
| 67 |
+
using ElementSource = ElementSource_;
|
| 68 |
+
using Layout = Layout_;
|
| 69 |
+
using ReductionOp = ReductionOp_;
|
| 70 |
+
static int const kVectorLength = VectorLength_;
|
| 71 |
+
using ElementCompute = ElementCompute_;
|
| 72 |
+
|
| 73 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 74 |
+
|
| 75 |
+
/// Reduction operator
|
| 76 |
+
using ReductionDeviceStridedOperator = TensorReductionAffineStrided<
|
| 77 |
+
4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute
|
| 78 |
+
>;
|
| 79 |
+
|
| 80 |
+
using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous<
|
| 81 |
+
4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute
|
| 82 |
+
>;
|
| 83 |
+
|
| 84 |
+
//
|
| 85 |
+
// Data members
|
| 86 |
+
//
|
| 87 |
+
|
| 88 |
+
ReductionDeviceStridedOperator reduction_strided;
|
| 89 |
+
ReductionDeviceContiguousOperator reduction_contiguous;
|
| 90 |
+
int reduction_index;
|
| 91 |
+
|
| 92 |
+
//
|
| 93 |
+
// Methods
|
| 94 |
+
//
|
| 95 |
+
|
| 96 |
+
///
|
| 97 |
+
TensorReduction(
|
| 98 |
+
TensorCoord extent,
|
| 99 |
+
int reduction_index_
|
| 100 |
+
):
|
| 101 |
+
reduction_index(reduction_index_) {
|
| 102 |
+
|
| 103 |
+
Coord<4> extent_affine;
|
| 104 |
+
|
| 105 |
+
switch (reduction_index) {
|
| 106 |
+
case 0:
|
| 107 |
+
extent_affine[0] = extent[1];
|
| 108 |
+
extent_affine[1] = extent[2];
|
| 109 |
+
extent_affine[2] = extent[0];
|
| 110 |
+
extent_affine[3] = extent[3];
|
| 111 |
+
break;
|
| 112 |
+
case 1:
|
| 113 |
+
extent_affine[0] = extent[0];
|
| 114 |
+
extent_affine[1] = extent[2];
|
| 115 |
+
extent_affine[2] = extent[1];
|
| 116 |
+
extent_affine[3] = extent[3];
|
| 117 |
+
break;
|
| 118 |
+
case 2:
|
| 119 |
+
extent_affine[0] = extent[0];
|
| 120 |
+
extent_affine[1] = extent[1];
|
| 121 |
+
extent_affine[2] = extent[2];
|
| 122 |
+
extent_affine[3] = extent[3];
|
| 123 |
+
break;
|
| 124 |
+
case 3:
|
| 125 |
+
extent_affine[0] = extent[0];
|
| 126 |
+
extent_affine[1] = extent[1];
|
| 127 |
+
extent_affine[2] = extent[2];
|
| 128 |
+
extent_affine[3] = extent[3];
|
| 129 |
+
break;
|
| 130 |
+
default: break;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
if (reduction_index == 3) {
|
| 134 |
+
reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine);
|
| 135 |
+
}
|
| 136 |
+
else {
|
| 137 |
+
reduction_strided = ReductionDeviceStridedOperator(extent_affine);
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Simple check to verify the object is initialized correctly
|
| 142 |
+
bool good() const {
|
| 143 |
+
if (reduction_index == 3) {
|
| 144 |
+
return reduction_contiguous.good();
|
| 145 |
+
}
|
| 146 |
+
return reduction_strided.good();
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/// Size of one workspace
|
| 150 |
+
int64_t workspace_stride() const {
|
| 151 |
+
if (reduction_index == 3) {
|
| 152 |
+
return reduction_contiguous.workspace_stride();
|
| 153 |
+
}
|
| 154 |
+
else {
|
| 155 |
+
return reduction_strided.workspace_stride();
|
| 156 |
+
}
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
|
| 160 |
+
int64_t workspace_size() const {
|
| 161 |
+
if (reduction_index == 3) {
|
| 162 |
+
return reduction_contiguous.workspace_size();
|
| 163 |
+
}
|
| 164 |
+
else {
|
| 165 |
+
return reduction_strided.workspace_size();
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/// Helper to use overloaded function call operator
|
| 170 |
+
Status reduce(
|
| 171 |
+
TensorRef<ElementOutput, Layout> dst_ref,
|
| 172 |
+
TensorRef<ElementSource, Layout> src_ref,
|
| 173 |
+
void *device_workspace_ptr = nullptr,
|
| 174 |
+
ElementCompute reduction_identity = ElementCompute(),
|
| 175 |
+
ReductionOp reduction_op = ReductionOp(),
|
| 176 |
+
cudaStream_t stream = nullptr) {
|
| 177 |
+
|
| 178 |
+
int64_t src_stride[3];
|
| 179 |
+
int64_t dst_stride[3];
|
| 180 |
+
|
| 181 |
+
switch (reduction_index) {
|
| 182 |
+
case 0:
|
| 183 |
+
src_stride[0] = src_ref.stride()[1];
|
| 184 |
+
src_stride[1] = src_ref.stride()[0];
|
| 185 |
+
src_stride[2] = src_ref.stride()[2];
|
| 186 |
+
dst_stride[0] = dst_ref.stride()[1];
|
| 187 |
+
dst_stride[1] = dst_ref.stride()[0];
|
| 188 |
+
break;
|
| 189 |
+
case 1:
|
| 190 |
+
src_stride[0] = src_ref.stride()[2];
|
| 191 |
+
src_stride[1] = src_ref.stride()[0];
|
| 192 |
+
src_stride[2] = src_ref.stride()[1];
|
| 193 |
+
dst_stride[0] = dst_ref.stride()[2];
|
| 194 |
+
dst_stride[1] = dst_ref.stride()[0];
|
| 195 |
+
break;
|
| 196 |
+
case 2:
|
| 197 |
+
src_stride[0] = src_ref.stride()[2];
|
| 198 |
+
src_stride[1] = src_ref.stride()[1];
|
| 199 |
+
src_stride[2] = src_ref.stride()[0];
|
| 200 |
+
dst_stride[0] = dst_ref.stride()[2];
|
| 201 |
+
dst_stride[1] = dst_ref.stride()[1];
|
| 202 |
+
break;
|
| 203 |
+
case 3:
|
| 204 |
+
src_stride[0] = src_ref.stride()[2];
|
| 205 |
+
src_stride[1] = src_ref.stride()[1];
|
| 206 |
+
src_stride[2] = src_ref.stride()[0];
|
| 207 |
+
|
| 208 |
+
dst_stride[0] = dst_ref.stride()[2];
|
| 209 |
+
dst_stride[1] = dst_ref.stride()[1];
|
| 210 |
+
dst_stride[2] = dst_ref.stride()[0];
|
| 211 |
+
|
| 212 |
+
default: break;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
if (reduction_index == 3) {
|
| 216 |
+
return reduction_contiguous(
|
| 217 |
+
dst_ref.data(),
|
| 218 |
+
dst_stride,
|
| 219 |
+
src_ref.data(),
|
| 220 |
+
src_stride,
|
| 221 |
+
device_workspace_ptr,
|
| 222 |
+
reduction_identity,
|
| 223 |
+
reduction_op,
|
| 224 |
+
stream);
|
| 225 |
+
}
|
| 226 |
+
else {
|
| 227 |
+
return reduction_strided(
|
| 228 |
+
dst_ref.data(),
|
| 229 |
+
dst_stride,
|
| 230 |
+
src_ref.data(),
|
| 231 |
+
src_stride,
|
| 232 |
+
device_workspace_ptr,
|
| 233 |
+
reduction_identity,
|
| 234 |
+
reduction_op,
|
| 235 |
+
stream);
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
Status operator()(
|
| 240 |
+
TensorRef<ElementOutput, Layout> dst_ref,
|
| 241 |
+
TensorRef<ElementSource, Layout> src_ref,
|
| 242 |
+
void *device_workspace_ptr = nullptr,
|
| 243 |
+
ElementCompute reduction_identity = ElementCompute(),
|
| 244 |
+
ReductionOp reduction_op = ReductionOp(),
|
| 245 |
+
cudaStream_t stream = nullptr) {
|
| 246 |
+
|
| 247 |
+
return reduce(
|
| 248 |
+
dst_ref,
|
| 249 |
+
src_ref,
|
| 250 |
+
device_workspace_ptr,
|
| 251 |
+
reduction_identity,
|
| 252 |
+
reduction_op,
|
| 253 |
+
stream);
|
| 254 |
+
}
|
| 255 |
+
};
|
| 256 |
+
|
| 257 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 258 |
+
|
| 259 |
+
} // namespace device
|
| 260 |
+
} // namespace reduction
|
| 261 |
+
} // namespace cutlass
|
| 262 |
+
|
| 263 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 264 |
+
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over one or more ranks of an affine tensor
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/device_kernel.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reduction {
|
| 50 |
+
namespace device {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Tensor reduction operator on layouts which are affine
|
| 55 |
+
template <
|
| 56 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 57 |
+
int ReducedRank, ///< Rank of reduced tensor (e.g. ND => 2)
|
| 58 |
+
typename ElementOutput_,
|
| 59 |
+
typename ElementSource_,
|
| 60 |
+
typename ReductionOp_,
|
| 61 |
+
int VectorLength = 1,
|
| 62 |
+
typename ElementCompute_ = ElementOutput_,
|
| 63 |
+
int Threads = 256, ///< Number of participating threads
|
| 64 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 65 |
+
>
|
| 66 |
+
struct TensorReductionAffineContiguous {
|
| 67 |
+
|
| 68 |
+
static int const kRank = Rank;
|
| 69 |
+
static int const kReducedRank = ReducedRank;
|
| 70 |
+
static int const kVectorLength = VectorLength;
|
| 71 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 72 |
+
static int const kThreads = Threads;
|
| 73 |
+
static int const kBatchSize = BatchSize;
|
| 74 |
+
|
| 75 |
+
using ElementOutput = ElementOutput_;
|
| 76 |
+
using ElementSource = ElementSource_;
|
| 77 |
+
using ReductionOp = ReductionOp_;
|
| 78 |
+
using ElementCompute = ElementCompute_;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// Data members
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
/// Internal status field
|
| 85 |
+
Status status;
|
| 86 |
+
|
| 87 |
+
/// Extent of tensor in source layout
|
| 88 |
+
Coord<kRank> extent;
|
| 89 |
+
|
| 90 |
+
/// Number of points in the outer index space
|
| 91 |
+
int64_t outer_count;
|
| 92 |
+
|
| 93 |
+
/// Number of elements in the inner index space
|
| 94 |
+
int64_t inner_count;
|
| 95 |
+
|
| 96 |
+
/// Number of workspaces needed
|
| 97 |
+
int workspace_count;
|
| 98 |
+
|
| 99 |
+
/// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner)
|
| 100 |
+
dim3 grid_shape;
|
| 101 |
+
|
| 102 |
+
/// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner)
|
| 103 |
+
dim3 threadblock_shape;
|
| 104 |
+
|
| 105 |
+
/// CUDA grid shape for the final reduction step if needed
|
| 106 |
+
dim3 grid_final;
|
| 107 |
+
|
| 108 |
+
/// CUDA threadblock shape for the final reduction step if needed
|
| 109 |
+
dim3 threadblock_final;
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
//
|
| 113 |
+
// Methods
|
| 114 |
+
//
|
| 115 |
+
|
| 116 |
+
/// Helper to reshape 'count' such that it is less than 2 x 'ext'
|
| 117 |
+
static int reshape_pow2(int ext, int count) {
|
| 118 |
+
if (ext > count) {
|
| 119 |
+
return 1;
|
| 120 |
+
}
|
| 121 |
+
int x = 1;
|
| 122 |
+
for (; count >= ext * 2; ) {
|
| 123 |
+
count >>= 1;
|
| 124 |
+
x <<= 1;
|
| 125 |
+
}
|
| 126 |
+
return x;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
public:
|
| 130 |
+
|
| 131 |
+
/// Default ctor
|
| 132 |
+
TensorReductionAffineContiguous():
|
| 133 |
+
status(Status::kErrorInvalidProblem),
|
| 134 |
+
extent(),
|
| 135 |
+
outer_count(0),
|
| 136 |
+
inner_count(0),
|
| 137 |
+
workspace_count(0),
|
| 138 |
+
grid_shape(0, 0, 0),
|
| 139 |
+
threadblock_shape(0, 0, 0) { }
|
| 140 |
+
|
| 141 |
+
/// Constructor
|
| 142 |
+
TensorReductionAffineContiguous(
|
| 143 |
+
Coord<kRank> extent_,
|
| 144 |
+
int target_threadblock_count = 128
|
| 145 |
+
):
|
| 146 |
+
status(Status::kSuccess),
|
| 147 |
+
extent(extent_),
|
| 148 |
+
outer_count(0),
|
| 149 |
+
inner_count(0),
|
| 150 |
+
workspace_count(0) {
|
| 151 |
+
|
| 152 |
+
//
|
| 153 |
+
// Plan the parallel mapping strategy.
|
| 154 |
+
//
|
| 155 |
+
|
| 156 |
+
outer_count = 1;
|
| 157 |
+
inner_count = 1;
|
| 158 |
+
|
| 159 |
+
// Compute number of elements in strided ranks
|
| 160 |
+
for (int p = 0; p < kReducedRank; ++p) {
|
| 161 |
+
outer_count *= extent[p];
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
for (int p = 0; p < kInnerRank; ++p) {
|
| 165 |
+
inner_count *= extent[kReducedRank + p];
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
int cta_count_x = 1;
|
| 169 |
+
int cta_count_y = 1;
|
| 170 |
+
int cta_count_z = 1;
|
| 171 |
+
|
| 172 |
+
int cta_threads_x = kThreads;
|
| 173 |
+
int cta_threads_y = 1;
|
| 174 |
+
int cta_threads_z = 1;
|
| 175 |
+
|
| 176 |
+
// Determine CTA shape
|
| 177 |
+
int64_t inner_vector_count = inner_count / kVectorLength;
|
| 178 |
+
|
| 179 |
+
// Priority 1. Assign threadblocks to outer indices if possible
|
| 180 |
+
if (outer_count > target_threadblock_count) {
|
| 181 |
+
cta_count_x = 1;
|
| 182 |
+
cta_count_y = target_threadblock_count;
|
| 183 |
+
cta_count_z = 1;
|
| 184 |
+
}
|
| 185 |
+
else {
|
| 186 |
+
|
| 187 |
+
cta_count_y = int(outer_count);
|
| 188 |
+
int remaining_ctas = target_threadblock_count / cta_count_y;
|
| 189 |
+
|
| 190 |
+
// Priority 2. Assign inner dimensions to one CTA
|
| 191 |
+
if (inner_vector_count > cta_threads_x) {
|
| 192 |
+
int64_t cta_z_bound = inner_vector_count / cta_threads_x;
|
| 193 |
+
if (cta_z_bound > remaining_ctas) {
|
| 194 |
+
cta_count_z = remaining_ctas;
|
| 195 |
+
}
|
| 196 |
+
else {
|
| 197 |
+
cta_count_z = int(cta_z_bound);
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
else {
|
| 201 |
+
cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x);
|
| 202 |
+
cta_count_z = 1;
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z);
|
| 207 |
+
threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z);
|
| 208 |
+
|
| 209 |
+
workspace_count = (cta_count_z > 1 ? cta_count_z : 0);
|
| 210 |
+
|
| 211 |
+
// Determine shape of final reduction kernel if needed
|
| 212 |
+
if (workspace_count) {
|
| 213 |
+
|
| 214 |
+
int final_threads = kThreads;
|
| 215 |
+
int final_ctas = 1;
|
| 216 |
+
|
| 217 |
+
if (outer_count > kThreads) {
|
| 218 |
+
final_ctas = int(outer_count + kThreads - 1) / kThreads;
|
| 219 |
+
}
|
| 220 |
+
else {
|
| 221 |
+
final_threads = int(outer_count);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
grid_final = dim3(final_ctas, 1, 1);
|
| 225 |
+
threadblock_final = dim3(final_threads, 1, 1);
|
| 226 |
+
}
|
| 227 |
+
else {
|
| 228 |
+
grid_final = dim3(0, 0, 0);
|
| 229 |
+
threadblock_final = dim3(0, 0, 0);
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/// Simple check to verify the object is initialized correctly
|
| 234 |
+
bool good() const {
|
| 235 |
+
return status == Status::kSuccess;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
/// Size (in bytes) of <outer_count> workspace elements which are densely packed together
|
| 239 |
+
int64_t workspace_stride() const {
|
| 240 |
+
|
| 241 |
+
// Error condition
|
| 242 |
+
if (!good()) {
|
| 243 |
+
return 0;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
return outer_count * sizeof_bits<ElementCompute>::value / 8;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
/// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
|
| 250 |
+
int64_t workspace_size() const {
|
| 251 |
+
|
| 252 |
+
// Error condition
|
| 253 |
+
if (!good()) {
|
| 254 |
+
return 0;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// No reduction across CTAs
|
| 258 |
+
if (grid_shape.z == 1) {
|
| 259 |
+
return 0;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
return workspace_stride() * grid_shape.z;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
/// Performs a reduction
|
| 266 |
+
Status reduce(
|
| 267 |
+
ElementOutput *dst_ptr, ///< Pointer to destination tensor
|
| 268 |
+
int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
|
| 269 |
+
ElementSource const *src_ptr, ///< Pointer to source tensor
|
| 270 |
+
int64_t src_stride[], ///< Stride vector (of length kRank - 1)
|
| 271 |
+
void *device_workspace_ptr = nullptr, ///< Device workspace
|
| 272 |
+
ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element
|
| 273 |
+
ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
|
| 274 |
+
cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
|
| 275 |
+
|
| 276 |
+
// Initial status check
|
| 277 |
+
if (!good()) {
|
| 278 |
+
return status;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
// Guard against null workspace
|
| 282 |
+
if (workspace_count > 1 && device_workspace_ptr == nullptr) {
|
| 283 |
+
return Status::kErrorWorkspaceNull;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
// Define reduction kernel
|
| 287 |
+
using ReductionKernel = kernel::TensorReductionAffineContiguous<
|
| 288 |
+
kRank,
|
| 289 |
+
kReducedRank,
|
| 290 |
+
ElementOutput,
|
| 291 |
+
ElementSource,
|
| 292 |
+
ReductionOp,
|
| 293 |
+
kVectorLength,
|
| 294 |
+
ElementCompute,
|
| 295 |
+
kThreads>;
|
| 296 |
+
|
| 297 |
+
using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal<
|
| 298 |
+
kRank,
|
| 299 |
+
kReducedRank,
|
| 300 |
+
ElementOutput,
|
| 301 |
+
ElementSource,
|
| 302 |
+
ReductionOp,
|
| 303 |
+
kVectorLength,
|
| 304 |
+
ElementCompute,
|
| 305 |
+
kThreads>;
|
| 306 |
+
|
| 307 |
+
using Params = typename ReductionKernel::Params;
|
| 308 |
+
|
| 309 |
+
// Construct the parameters
|
| 310 |
+
Params params(
|
| 311 |
+
extent,
|
| 312 |
+
dst_ptr,
|
| 313 |
+
dst_stride,
|
| 314 |
+
src_ptr,
|
| 315 |
+
src_stride,
|
| 316 |
+
static_cast<ElementCompute *>(device_workspace_ptr),
|
| 317 |
+
workspace_stride(),
|
| 318 |
+
workspace_count,
|
| 319 |
+
reduction_op,
|
| 320 |
+
reduction_identity);
|
| 321 |
+
|
| 322 |
+
// Shared memory size
|
| 323 |
+
int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage);
|
| 324 |
+
|
| 325 |
+
// Launch the kernel
|
| 326 |
+
cutlass::arch::synclog_setup();
|
| 327 |
+
Kernel<ReductionKernel><<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params);
|
| 328 |
+
|
| 329 |
+
// Check error condition
|
| 330 |
+
if (cudaPeekAtLastError() == cudaSuccess) {
|
| 331 |
+
status = Status::kSuccess;
|
| 332 |
+
}
|
| 333 |
+
else {
|
| 334 |
+
status = Status::kErrorInternal;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
// Final reduction kernel
|
| 338 |
+
if (workspace_count) {
|
| 339 |
+
Kernel<FinalReductionKernel><<< grid_final, threadblock_final, 0, stream >>>(params);
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
// Check error condition
|
| 343 |
+
if (cudaPeekAtLastError() == cudaSuccess) {
|
| 344 |
+
status = Status::kSuccess;
|
| 345 |
+
}
|
| 346 |
+
else {
|
| 347 |
+
status = Status::kErrorInternal;
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
return status;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
/// Helper to use overloaded function call operator
|
| 354 |
+
Status operator()(
|
| 355 |
+
ElementOutput *dst_ptr, ///< Pointer to destination tensor
|
| 356 |
+
int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
|
| 357 |
+
ElementSource const *src_ptr, ///< Pointer to source tensor
|
| 358 |
+
int64_t src_stride[], ///< Stride vector (of length kRank - 1)
|
| 359 |
+
void *device_workspace_ptr = nullptr, ///< Pointer to device workspace
|
| 360 |
+
ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element
|
| 361 |
+
ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
|
| 362 |
+
cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
|
| 363 |
+
|
| 364 |
+
return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream);
|
| 365 |
+
}
|
| 366 |
+
};
|
| 367 |
+
|
| 368 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 369 |
+
|
| 370 |
+
} // namespace device
|
| 371 |
+
} // namespace reduction
|
| 372 |
+
} // namespace cutlass
|
| 373 |
+
|
| 374 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over one or more ranks of an affine tensor
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/device_kernel.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reduction {
|
| 50 |
+
namespace device {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Tensor reduction operator on layouts which are affine
|
| 55 |
+
template <
|
| 56 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 57 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 58 |
+
typename ElementOutput_,
|
| 59 |
+
typename ElementSource_,
|
| 60 |
+
typename ReductionOp_,
|
| 61 |
+
int VectorLength = 1,
|
| 62 |
+
typename ElementCompute_ = ElementOutput_,
|
| 63 |
+
int Threads = 256, ///< Number of participating threads
|
| 64 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 65 |
+
>
|
| 66 |
+
struct TensorReductionAffineStrided {
|
| 67 |
+
|
| 68 |
+
static int const kRank = Rank;
|
| 69 |
+
static int const kReducedRank = ReducedRank;
|
| 70 |
+
static int const kVectorLength = VectorLength;
|
| 71 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 72 |
+
static int const kThreads = Threads;
|
| 73 |
+
static int const kBatchSize = BatchSize;
|
| 74 |
+
|
| 75 |
+
using ElementOutput = ElementOutput_;
|
| 76 |
+
using ElementSource = ElementSource_;
|
| 77 |
+
using ReductionOp = ReductionOp_;
|
| 78 |
+
using ElementCompute = ElementCompute_;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// Data members
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
/// Internal status field
|
| 85 |
+
Status status;
|
| 86 |
+
|
| 87 |
+
/// Extent of tensor in source layout
|
| 88 |
+
Coord<kRank> extent;
|
| 89 |
+
|
| 90 |
+
/// Number of points in the outer index space
|
| 91 |
+
int64_t outer_count;
|
| 92 |
+
|
| 93 |
+
/// Number of elements in the inner index space
|
| 94 |
+
int64_t inner_count;
|
| 95 |
+
|
| 96 |
+
/// Number of workspaces needed
|
| 97 |
+
int workspace_count;
|
| 98 |
+
|
| 99 |
+
/// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner)
|
| 100 |
+
dim3 grid_shape;
|
| 101 |
+
|
| 102 |
+
/// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner)
|
| 103 |
+
dim3 threadblock_shape;
|
| 104 |
+
|
| 105 |
+
/// CUDA grid shape for the final reduction step if needed
|
| 106 |
+
dim3 grid_final;
|
| 107 |
+
|
| 108 |
+
/// CUDA threadblock shape for the final reduction step if needed
|
| 109 |
+
dim3 threadblock_final;
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
//
|
| 113 |
+
// Methods
|
| 114 |
+
//
|
| 115 |
+
|
| 116 |
+
/// Helper to reshape 'count' such that it is less than 2 x 'ext'
|
| 117 |
+
static int reshape_pow2(int ext, int count) {
|
| 118 |
+
if (ext > count) {
|
| 119 |
+
return 1;
|
| 120 |
+
}
|
| 121 |
+
int x = 1;
|
| 122 |
+
for (; count >= ext * 2; ) {
|
| 123 |
+
count >>= 1;
|
| 124 |
+
x <<= 1;
|
| 125 |
+
}
|
| 126 |
+
return x;
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
public:
|
| 130 |
+
|
| 131 |
+
/// Default ctor
|
| 132 |
+
TensorReductionAffineStrided():
|
| 133 |
+
status(Status::kErrorInvalidProblem),
|
| 134 |
+
extent(),
|
| 135 |
+
outer_count(0),
|
| 136 |
+
inner_count(0),
|
| 137 |
+
workspace_count(0),
|
| 138 |
+
grid_shape(0, 0, 0),
|
| 139 |
+
threadblock_shape(0, 0, 0) { }
|
| 140 |
+
|
| 141 |
+
/// Constructor
|
| 142 |
+
TensorReductionAffineStrided(
|
| 143 |
+
Coord<kRank> extent_,
|
| 144 |
+
int target_threadblock_count = 128
|
| 145 |
+
):
|
| 146 |
+
status(Status::kSuccess),
|
| 147 |
+
extent(extent_),
|
| 148 |
+
outer_count(0),
|
| 149 |
+
inner_count(0),
|
| 150 |
+
workspace_count(0) {
|
| 151 |
+
|
| 152 |
+
//
|
| 153 |
+
// Plan the parallel mapping strategy.
|
| 154 |
+
//
|
| 155 |
+
|
| 156 |
+
outer_count = 1;
|
| 157 |
+
inner_count = 1;
|
| 158 |
+
|
| 159 |
+
// Compute number of elements in strided ranks
|
| 160 |
+
for (int p = 0; p < kReducedRank - 1; ++p) {
|
| 161 |
+
outer_count *= extent[p];
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
for (int p = 0; p < kInnerRank; ++p) {
|
| 165 |
+
inner_count *= extent[kReducedRank + p - 1];
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
// Compute plan for the reduction
|
| 169 |
+
int extent_c = extent[kRank - 1];
|
| 170 |
+
int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength;
|
| 171 |
+
|
| 172 |
+
// Determine CTA shape
|
| 173 |
+
int cta_width = kThreads * kVectorLength;
|
| 174 |
+
int cta_ways = reshape_pow2(extent_c, cta_width);
|
| 175 |
+
int cta_threads_x = kThreads / cta_ways;
|
| 176 |
+
|
| 177 |
+
threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64));
|
| 178 |
+
|
| 179 |
+
// This leads to an error.
|
| 180 |
+
if (threadblock_shape.z > 1) {
|
| 181 |
+
if (threadblock_shape.y != 1) {
|
| 182 |
+
status = Status::kErrorInternal;
|
| 183 |
+
return;
|
| 184 |
+
}
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
// Determine grid shape
|
| 188 |
+
int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x;
|
| 189 |
+
int cta_count_y = std::max(1, target_threadblock_count / cta_count_x);
|
| 190 |
+
|
| 191 |
+
// Limit the number of CTAs assigned to outer dimension
|
| 192 |
+
if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) {
|
| 193 |
+
cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y;
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
// Limit the number of CTAs assigned to inner dimension
|
| 197 |
+
int cta_count_z = std::max(1, target_threadblock_count / cta_count_y);
|
| 198 |
+
if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) {
|
| 199 |
+
cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z);
|
| 203 |
+
workspace_count = (cta_count_z > 1 ? cta_count_z : 0);
|
| 204 |
+
|
| 205 |
+
// Determine shape of final reduction kernel if needed
|
| 206 |
+
grid_final = dim3(cta_count_x, int(outer_count));
|
| 207 |
+
threadblock_final = dim3(cta_threads_x, 1, 1);
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/// Simple check to verify the object is initialized correctly
|
| 211 |
+
bool good() const {
|
| 212 |
+
return status == Status::kSuccess;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
/// Size of one CTA's workspace
|
| 216 |
+
int64_t workspace_stride() const {
|
| 217 |
+
|
| 218 |
+
// Error condition
|
| 219 |
+
if (!good()) {
|
| 220 |
+
return 0;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
int vector_size_bytes = kVectorLength * sizeof_bits<ElementCompute>::value / 8;
|
| 224 |
+
|
| 225 |
+
return extent[kRank - 1] * vector_size_bytes;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
/// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs
|
| 229 |
+
int64_t workspace_size() const {
|
| 230 |
+
|
| 231 |
+
// Error condition
|
| 232 |
+
if (!good()) {
|
| 233 |
+
return 0;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
// No reduction across CTAs
|
| 237 |
+
if (grid_shape.z == 1) {
|
| 238 |
+
return 0;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
return workspace_stride() * outer_count * grid_shape.z;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/// Performs a reduction
|
| 245 |
+
Status reduce(
|
| 246 |
+
ElementOutput *dst_ptr, ///< Pointer to destination tensor
|
| 247 |
+
int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
|
| 248 |
+
ElementSource const *src_ptr, ///< Pointer to source tensor
|
| 249 |
+
int64_t src_stride[], ///< Stride vector (of length kRank - 1)
|
| 250 |
+
void *device_workspace_ptr = nullptr, ///< Device workspace
|
| 251 |
+
ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity
|
| 252 |
+
ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
|
| 253 |
+
cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
|
| 254 |
+
|
| 255 |
+
// Initial status check
|
| 256 |
+
if (!good()) {
|
| 257 |
+
return status;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
// Guard against null workspace
|
| 261 |
+
if (workspace_count > 1 && device_workspace_ptr == nullptr) {
|
| 262 |
+
return Status::kErrorWorkspaceNull;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
// Define reduction kernel
|
| 266 |
+
using ReductionKernel = kernel::TensorReductionAffineStrided<
|
| 267 |
+
kRank,
|
| 268 |
+
kReducedRank,
|
| 269 |
+
ElementOutput,
|
| 270 |
+
ElementSource,
|
| 271 |
+
ReductionOp,
|
| 272 |
+
kVectorLength,
|
| 273 |
+
ElementCompute,
|
| 274 |
+
kThreads>;
|
| 275 |
+
|
| 276 |
+
using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal<
|
| 277 |
+
kRank,
|
| 278 |
+
kReducedRank,
|
| 279 |
+
ElementOutput,
|
| 280 |
+
ElementSource,
|
| 281 |
+
ReductionOp,
|
| 282 |
+
kVectorLength,
|
| 283 |
+
ElementCompute,
|
| 284 |
+
kThreads>;
|
| 285 |
+
|
| 286 |
+
using Params = typename ReductionKernel::Params;
|
| 287 |
+
|
| 288 |
+
// Construct the parameters
|
| 289 |
+
Params params(
|
| 290 |
+
extent,
|
| 291 |
+
dst_ptr,
|
| 292 |
+
dst_stride,
|
| 293 |
+
src_ptr,
|
| 294 |
+
src_stride,
|
| 295 |
+
static_cast<ElementCompute *>(device_workspace_ptr),
|
| 296 |
+
workspace_stride(),
|
| 297 |
+
workspace_count,
|
| 298 |
+
reduction_op,
|
| 299 |
+
reduction_identity);
|
| 300 |
+
|
| 301 |
+
// Shared memory size
|
| 302 |
+
int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage);
|
| 303 |
+
|
| 304 |
+
// Launch the kernel
|
| 305 |
+
cutlass::arch::synclog_setup();
|
| 306 |
+
Kernel<ReductionKernel><<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params);
|
| 307 |
+
|
| 308 |
+
// Check error condition
|
| 309 |
+
if (cudaPeekAtLastError() == cudaSuccess) {
|
| 310 |
+
status = Status::kSuccess;
|
| 311 |
+
}
|
| 312 |
+
else {
|
| 313 |
+
status = Status::kErrorInternal;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
// Final reduction kernel
|
| 317 |
+
if (workspace_count) {
|
| 318 |
+
|
| 319 |
+
Kernel<FinalReductionKernel><<< grid_final, threadblock_final, 0, stream >>>(params);
|
| 320 |
+
|
| 321 |
+
// Check error condition
|
| 322 |
+
if (cudaPeekAtLastError() == cudaSuccess) {
|
| 323 |
+
status = Status::kSuccess;
|
| 324 |
+
}
|
| 325 |
+
else {
|
| 326 |
+
status = Status::kErrorInternal;
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
return status;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
/// Helper to use overloaded function call operator
|
| 334 |
+
Status operator()(
|
| 335 |
+
ElementOutput *dst_ptr, ///< Pointer to destination tensor
|
| 336 |
+
int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1)
|
| 337 |
+
ElementSource const *src_ptr, ///< Pointer to source tensor
|
| 338 |
+
int64_t src_stride[], ///< Stride vector (of length kRank - 1)
|
| 339 |
+
void *device_workspace_ptr = nullptr, ///< Pointer to device workspace
|
| 340 |
+
ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity
|
| 341 |
+
ReductionOp reduction_op = ReductionOp(), ///< Reduction operator
|
| 342 |
+
cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched
|
| 343 |
+
|
| 344 |
+
return reduce(
|
| 345 |
+
dst_ptr,
|
| 346 |
+
dst_stride,
|
| 347 |
+
src_ptr,
|
| 348 |
+
src_stride,
|
| 349 |
+
device_workspace_ptr,
|
| 350 |
+
reduction_identity,
|
| 351 |
+
reduction_op,
|
| 352 |
+
stream);
|
| 353 |
+
}
|
| 354 |
+
};
|
| 355 |
+
|
| 356 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 357 |
+
|
| 358 |
+
} // namespace device
|
| 359 |
+
} // namespace reduction
|
| 360 |
+
} // namespace cutlass
|
| 361 |
+
|
| 362 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a final reduction for softmax
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/functional.h"
|
| 41 |
+
#include "cutlass/matrix_shape.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
#include "cutlass/arch/memory.h"
|
| 44 |
+
#include "cutlass/arch/memory_sm75.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reduction {
|
| 50 |
+
namespace kernel {
|
| 51 |
+
|
| 52 |
+
template <
|
| 53 |
+
typename ElementNorm_,
|
| 54 |
+
typename ElementSum_,
|
| 55 |
+
typename ElementSoftmaxCompute_,
|
| 56 |
+
typename ThreadblockShape_,
|
| 57 |
+
bool GroupedProblem = false
|
| 58 |
+
>
|
| 59 |
+
class ApplySoftmaxFinalReduction {
|
| 60 |
+
public:
|
| 61 |
+
|
| 62 |
+
using ElementNorm = ElementNorm_;
|
| 63 |
+
using ElementSum = ElementSum_;
|
| 64 |
+
using ElementSoftmaxCompute = ElementSoftmaxCompute_;
|
| 65 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 66 |
+
static const bool isGroupedProblem = GroupedProblem;
|
| 67 |
+
|
| 68 |
+
//
|
| 69 |
+
// Arguments
|
| 70 |
+
//
|
| 71 |
+
|
| 72 |
+
struct Arguments {
|
| 73 |
+
|
| 74 |
+
cutlass::gemm::GemmCoord* problem_sizes{nullptr};
|
| 75 |
+
cutlass::gemm::GemmCoord problem_size{};
|
| 76 |
+
ElementNorm* block_Norm{nullptr};
|
| 77 |
+
ElementSum* block_Sum{nullptr};
|
| 78 |
+
int64_t* offset_Norm_Device{nullptr};
|
| 79 |
+
int64_t* offset_Sum_Device{nullptr};
|
| 80 |
+
int64_t batch_stride_Max{0};
|
| 81 |
+
int64_t batch_stride_Sum{0};
|
| 82 |
+
|
| 83 |
+
//
|
| 84 |
+
// Methods
|
| 85 |
+
//
|
| 86 |
+
Arguments() { }
|
| 87 |
+
|
| 88 |
+
// Non-grouped constructor without batching
|
| 89 |
+
Arguments(
|
| 90 |
+
cutlass::gemm::GemmCoord problem_size,
|
| 91 |
+
ElementNorm* block_Norm,
|
| 92 |
+
ElementSum* block_Sum
|
| 93 |
+
):
|
| 94 |
+
problem_size(problem_size),
|
| 95 |
+
block_Norm(block_Norm),
|
| 96 |
+
block_Sum(block_Sum),
|
| 97 |
+
problem_sizes(nullptr),
|
| 98 |
+
offset_Norm_Device(nullptr),
|
| 99 |
+
offset_Sum_Device(nullptr),
|
| 100 |
+
batch_stride_Max(0),
|
| 101 |
+
batch_stride_Sum(0)
|
| 102 |
+
{
|
| 103 |
+
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// Non-grouped constructor with batching
|
| 107 |
+
Arguments(
|
| 108 |
+
cutlass::gemm::GemmCoord problem_size,
|
| 109 |
+
ElementNorm* block_Norm,
|
| 110 |
+
ElementSum* block_Sum,
|
| 111 |
+
int64_t batch_stride_Max,
|
| 112 |
+
int64_t batch_stride_Sum
|
| 113 |
+
):
|
| 114 |
+
problem_size(problem_size),
|
| 115 |
+
block_Norm(block_Norm),
|
| 116 |
+
block_Sum(block_Sum),
|
| 117 |
+
batch_stride_Max(batch_stride_Max),
|
| 118 |
+
batch_stride_Sum(batch_stride_Sum),
|
| 119 |
+
problem_sizes(nullptr),
|
| 120 |
+
offset_Norm_Device(nullptr),
|
| 121 |
+
offset_Sum_Device(nullptr)
|
| 122 |
+
{
|
| 123 |
+
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
// Grouped constructor
|
| 128 |
+
Arguments(
|
| 129 |
+
cutlass::gemm::GemmCoord *problem_sizes,
|
| 130 |
+
ElementNorm* block_Norm,
|
| 131 |
+
ElementSum* block_Sum,
|
| 132 |
+
int64_t* offset_Norm_Device,
|
| 133 |
+
int64_t* offset_Sum_Device
|
| 134 |
+
):
|
| 135 |
+
problem_sizes(problem_sizes),
|
| 136 |
+
problem_size(cutlass::gemm::GemmCoord(0, 0, 0)),
|
| 137 |
+
block_Norm(block_Norm),
|
| 138 |
+
block_Sum(block_Sum),
|
| 139 |
+
offset_Norm_Device(offset_Norm_Device),
|
| 140 |
+
offset_Sum_Device(offset_Sum_Device)
|
| 141 |
+
{
|
| 142 |
+
|
| 143 |
+
}
|
| 144 |
+
};
|
| 145 |
+
|
| 146 |
+
struct SharedStorage {
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
};
|
| 150 |
+
|
| 151 |
+
//
|
| 152 |
+
// Params struct
|
| 153 |
+
//
|
| 154 |
+
|
| 155 |
+
struct Params {
|
| 156 |
+
Arguments args;
|
| 157 |
+
|
| 158 |
+
//
|
| 159 |
+
// Methods
|
| 160 |
+
//
|
| 161 |
+
Params() { }
|
| 162 |
+
|
| 163 |
+
Params(Arguments const &args_): args(args_) { }
|
| 164 |
+
};
|
| 165 |
+
|
| 166 |
+
private:
|
| 167 |
+
|
| 168 |
+
public:
|
| 169 |
+
|
| 170 |
+
CUTLASS_DEVICE
|
| 171 |
+
ApplySoftmaxFinalReduction() { }
|
| 172 |
+
|
| 173 |
+
CUTLASS_DEVICE
|
| 174 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 175 |
+
|
| 176 |
+
apply(params, shared_storage);
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
private:
|
| 180 |
+
|
| 181 |
+
/// Full reduction
|
| 182 |
+
CUTLASS_DEVICE
|
| 183 |
+
void apply(Params const ¶ms, SharedStorage &shared_storage) {
|
| 184 |
+
|
| 185 |
+
int tid = threadIdx.x;
|
| 186 |
+
int bid = blockIdx.x;
|
| 187 |
+
int bdim = blockDim.x;
|
| 188 |
+
|
| 189 |
+
int block_batch = blockIdx.z;
|
| 190 |
+
|
| 191 |
+
// defining three vars for a general reduction module
|
| 192 |
+
cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size;
|
| 193 |
+
int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim;
|
| 194 |
+
int access_offset = isGroupedProblem ? 0 : bid * bdim;
|
| 195 |
+
|
| 196 |
+
if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return;
|
| 197 |
+
|
| 198 |
+
ElementNorm *curr_ptr_Max = isGroupedProblem ? \
|
| 199 |
+
params.args.block_Norm + params.args.offset_Norm_Device[bid] : \
|
| 200 |
+
params.args.block_Norm + block_batch * params.args.batch_stride_Max;
|
| 201 |
+
ElementSum *curr_ptr_Sum = isGroupedProblem ? \
|
| 202 |
+
params.args.block_Sum + params.args.offset_Sum_Device[bid] : \
|
| 203 |
+
params.args.block_Sum + block_batch * params.args.batch_stride_Sum;
|
| 204 |
+
|
| 205 |
+
int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN;
|
| 206 |
+
|
| 207 |
+
using ConvertSumOutput = cutlass::NumericConverter<ElementSum, ElementSoftmaxCompute>;
|
| 208 |
+
using ConvertNormOutput = cutlass::NumericConverter<ElementNorm, ElementSoftmaxCompute>;
|
| 209 |
+
|
| 210 |
+
using ConvertSum = cutlass::NumericConverter<ElementSoftmaxCompute, ElementSum>;
|
| 211 |
+
using ConvertNorm = cutlass::NumericConverter<ElementSoftmaxCompute, ElementNorm>;
|
| 212 |
+
|
| 213 |
+
ConvertSum convert_sum;
|
| 214 |
+
ConvertNorm convert_norm;
|
| 215 |
+
|
| 216 |
+
ConvertSumOutput convert_sum_output;
|
| 217 |
+
ConvertNormOutput convert_norm_output;
|
| 218 |
+
|
| 219 |
+
uint32_t float_max_bits = 0xff7fffff;
|
| 220 |
+
float min_float = reinterpret_cast<float const &>(float_max_bits);
|
| 221 |
+
|
| 222 |
+
CUTLASS_PRAGMA_UNROLL
|
| 223 |
+
for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) {
|
| 224 |
+
ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset;
|
| 225 |
+
ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset;
|
| 226 |
+
ElementNorm *access_n_bak = access_n;
|
| 227 |
+
ElementSum *access_s_bak = access_s;
|
| 228 |
+
ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float);
|
| 229 |
+
ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0);
|
| 230 |
+
ElementNorm fetch_n;
|
| 231 |
+
ElementSum fetch_s;
|
| 232 |
+
|
| 233 |
+
CUTLASS_PRAGMA_UNROLL
|
| 234 |
+
for (int idx_n = 0; idx_n < threadblock_num; idx_n++) {
|
| 235 |
+
cutlass::arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
|
| 236 |
+
max_val = cutlass::fast_max(max_val, convert_norm(fetch_n));
|
| 237 |
+
access_n += problem_size.m();
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
access_n = access_n_bak;
|
| 241 |
+
|
| 242 |
+
CUTLASS_PRAGMA_UNROLL
|
| 243 |
+
for (int idx_n = 0; idx_n < threadblock_num; idx_n++) {
|
| 244 |
+
cutlass::arch::global_load<ElementNorm, sizeof(ElementNorm)>(fetch_n, access_n, true);
|
| 245 |
+
cutlass::arch::global_load<ElementSum, sizeof(ElementSum)>(fetch_s, access_s, true);
|
| 246 |
+
sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val);
|
| 247 |
+
access_n += problem_size.m();
|
| 248 |
+
access_s += problem_size.m();
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
ElementSoftmaxCompute inv_sum = cutlass::constants::one<ElementSoftmaxCompute>() / sum_val;
|
| 252 |
+
|
| 253 |
+
access_n = access_n_bak;
|
| 254 |
+
access_s = access_s_bak;
|
| 255 |
+
|
| 256 |
+
access_n[0] = convert_norm_output(max_val);
|
| 257 |
+
access_s[0] = convert_sum_output(inv_sum);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
}
|
| 261 |
+
};
|
| 262 |
+
|
| 263 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 264 |
+
|
| 265 |
+
} // namespace kernel
|
| 266 |
+
} // namespace reduction
|
| 267 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over densely packed tensors in global memory
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/tensor_ref.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/numeric_conversion.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/layout/matrix.h"
|
| 46 |
+
|
| 47 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace reduction {
|
| 51 |
+
namespace kernel {
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
template <
|
| 56 |
+
typename Shape_, ///< shape of CTA (concept: MatrixShape)
|
| 57 |
+
typename OutputOp_ , ///< output operator (concept: epilogue::thread operator)
|
| 58 |
+
typename ReductionOp_, ///< reduction operator (concept: ReductionOperator)
|
| 59 |
+
int PartitionsPerStage = 4 ///< number of partitions to issue
|
| 60 |
+
>
|
| 61 |
+
class ReduceSplitK {
|
| 62 |
+
public:
|
| 63 |
+
|
| 64 |
+
using Shape = Shape_;
|
| 65 |
+
using ReductionOp = ReductionOp_;
|
| 66 |
+
using OutputOp = OutputOp_;
|
| 67 |
+
static int const kElementsPerAccess = OutputOp::kCount;
|
| 68 |
+
static int const kPartitionsPerStage = PartitionsPerStage;
|
| 69 |
+
|
| 70 |
+
using ElementWorkspace = typename ReductionOp::Element;
|
| 71 |
+
using ElementAccumulator = typename ReductionOp::ElementAccumulator;
|
| 72 |
+
using ElementOutput = typename OutputOp::ElementOutput;
|
| 73 |
+
|
| 74 |
+
using WorkspaceTensorRef = TensorRef<ElementWorkspace, layout::RowMajor>;
|
| 75 |
+
using OutputTensorRef = TensorRef<ElementOutput, layout::RowMajor>;
|
| 76 |
+
using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index;
|
| 77 |
+
|
| 78 |
+
using FragmentWorkspace = AlignedArray<ElementWorkspace, kElementsPerAccess>;
|
| 79 |
+
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
|
| 80 |
+
using FragmentOutput = AlignedArray<ElementOutput, kElementsPerAccess>;
|
| 81 |
+
|
| 82 |
+
//
|
| 83 |
+
// Types
|
| 84 |
+
//
|
| 85 |
+
|
| 86 |
+
/// Params structure
|
| 87 |
+
struct Params {
|
| 88 |
+
|
| 89 |
+
MatrixCoord problem_size;
|
| 90 |
+
int partitions;
|
| 91 |
+
size_t partition_stride;
|
| 92 |
+
WorkspaceTensorRef workspace;
|
| 93 |
+
OutputTensorRef destination;
|
| 94 |
+
OutputTensorRef source;
|
| 95 |
+
typename OutputOp::Params output;
|
| 96 |
+
typename ReductionOp::Params reduction;
|
| 97 |
+
|
| 98 |
+
//
|
| 99 |
+
// Methods
|
| 100 |
+
//
|
| 101 |
+
|
| 102 |
+
CUTLASS_HOST_DEVICE
|
| 103 |
+
Params() { }
|
| 104 |
+
|
| 105 |
+
CUTLASS_HOST_DEVICE
|
| 106 |
+
Params(
|
| 107 |
+
MatrixCoord problem_size_,
|
| 108 |
+
int partitions_,
|
| 109 |
+
size_t partition_stride_,
|
| 110 |
+
WorkspaceTensorRef workspace_,
|
| 111 |
+
OutputTensorRef destination_,
|
| 112 |
+
OutputTensorRef source_,
|
| 113 |
+
typename OutputOp::Params output_ = typename OutputOp::Params(),
|
| 114 |
+
typename ReductionOp::Params reduction_ = typename ReductionOp::Params()
|
| 115 |
+
):
|
| 116 |
+
problem_size(problem_size_),
|
| 117 |
+
partitions(partitions_),
|
| 118 |
+
partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess),
|
| 119 |
+
workspace(workspace_),
|
| 120 |
+
destination(destination_),
|
| 121 |
+
source(source_),
|
| 122 |
+
output(output_),
|
| 123 |
+
reduction(reduction_) {
|
| 124 |
+
|
| 125 |
+
}
|
| 126 |
+
};
|
| 127 |
+
|
| 128 |
+
struct SharedStorage { };
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
public:
|
| 132 |
+
|
| 133 |
+
/// Computes the grid size given a chosen threadblock shape
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
static dim3 grid_shape(
|
| 136 |
+
cutlass::MatrixCoord problem_size) {
|
| 137 |
+
|
| 138 |
+
return dim3(
|
| 139 |
+
(problem_size.row() + Shape::kRow - 1) / Shape::kRow,
|
| 140 |
+
(problem_size.column() + Shape::kColumn - 1) / Shape::kColumn);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/// Determines the threadblock shape
|
| 144 |
+
CUTLASS_HOST_DEVICE
|
| 145 |
+
static dim3 block_shape() {
|
| 146 |
+
return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow);
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/// Perform a reduction
|
| 150 |
+
CUTLASS_DEVICE
|
| 151 |
+
void operator()(Params const ¶ms, SharedStorage &storage) {
|
| 152 |
+
|
| 153 |
+
// Determine CTA position
|
| 154 |
+
MatrixCoord thread_offset(
|
| 155 |
+
MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y),
|
| 156 |
+
MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess)
|
| 157 |
+
);
|
| 158 |
+
|
| 159 |
+
// One guard conditional
|
| 160 |
+
if (!(thread_offset.row() < params.problem_size.row() &&
|
| 161 |
+
thread_offset.column() < params.problem_size.column())) {
|
| 162 |
+
|
| 163 |
+
return;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
ReductionOp reduction_op(params.reduction);
|
| 168 |
+
|
| 169 |
+
FragmentAccumulator accumulator;
|
| 170 |
+
|
| 171 |
+
accumulator.clear();
|
| 172 |
+
|
| 173 |
+
//
|
| 174 |
+
// Load the first slice
|
| 175 |
+
//
|
| 176 |
+
|
| 177 |
+
char const *workspace_ptr =
|
| 178 |
+
reinterpret_cast<char const *>(
|
| 179 |
+
params.workspace.data() + params.workspace.offset(thread_offset));
|
| 180 |
+
|
| 181 |
+
FragmentWorkspace workspace_frag[kPartitionsPerStage];
|
| 182 |
+
|
| 183 |
+
//
|
| 184 |
+
// Construct the output operator
|
| 185 |
+
//
|
| 186 |
+
|
| 187 |
+
OutputOp output_op(params.output);
|
| 188 |
+
|
| 189 |
+
//
|
| 190 |
+
// Load and accumulate with a simple batched loading sequence.
|
| 191 |
+
//
|
| 192 |
+
|
| 193 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 194 |
+
for (int k = 0; k < params.partitions; k += kPartitionsPerStage) {
|
| 195 |
+
|
| 196 |
+
CUTLASS_PRAGMA_UNROLL
|
| 197 |
+
for (int i = 0; i < kPartitionsPerStage; ++i) {
|
| 198 |
+
if (k + i < params.partitions) {
|
| 199 |
+
workspace_frag[i] = *reinterpret_cast<FragmentWorkspace const *>(workspace_ptr);
|
| 200 |
+
workspace_ptr += params.partition_stride;
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
CUTLASS_PRAGMA_UNROLL
|
| 205 |
+
for (int i = 0; i < kPartitionsPerStage; ++i) {
|
| 206 |
+
if (k + i < params.partitions) {
|
| 207 |
+
accumulator = reduction_op(accumulator, workspace_frag[i]);
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
//
|
| 213 |
+
// Conditionally load the source
|
| 214 |
+
//
|
| 215 |
+
|
| 216 |
+
FragmentOutput source_frag;
|
| 217 |
+
|
| 218 |
+
source_frag.clear();
|
| 219 |
+
|
| 220 |
+
FragmentOutput const *source_ptr = reinterpret_cast<FragmentOutput const *>(
|
| 221 |
+
params.source.data() + params.source.offset(thread_offset));
|
| 222 |
+
|
| 223 |
+
if (output_op.is_source_needed()) {
|
| 224 |
+
reinterpret_cast<FragmentOutput &>(source_frag) = *source_ptr;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
//
|
| 228 |
+
// Compute the output
|
| 229 |
+
//
|
| 230 |
+
|
| 231 |
+
typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag);
|
| 232 |
+
|
| 233 |
+
//
|
| 234 |
+
// Store
|
| 235 |
+
//
|
| 236 |
+
|
| 237 |
+
FragmentOutput *dest_ptr = reinterpret_cast<FragmentOutput *>(
|
| 238 |
+
params.destination.data() + params.destination.offset(thread_offset));
|
| 239 |
+
|
| 240 |
+
*dest_ptr = reinterpret_cast<FragmentOutput const &>(output_frag);
|
| 241 |
+
}
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 245 |
+
|
| 246 |
+
} // namespace kernel
|
| 247 |
+
} // namespace reduction
|
| 248 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h
ADDED
|
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over one or more ranks of an affine tensor
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/device_kernel.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reduction {
|
| 50 |
+
namespace kernel {
|
| 51 |
+
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// Parameters structure
|
| 55 |
+
template <
|
| 56 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 57 |
+
int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks)
|
| 58 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 59 |
+
typename ElementSource, ///< Data type of source tensor
|
| 60 |
+
typename ReductionOp, ///< Reduction operator
|
| 61 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 62 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 63 |
+
int Threads = 256, ///< Number of participating threads
|
| 64 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 65 |
+
>
|
| 66 |
+
struct TensorReductionAffineContiguousParams {
|
| 67 |
+
|
| 68 |
+
static int const kRank = Rank;
|
| 69 |
+
static int const kReducedRank = ReducedRank;
|
| 70 |
+
static int const kVectorLength = VectorLength;
|
| 71 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 72 |
+
static int const kThreads = Threads;
|
| 73 |
+
static int const kBatchSize = BatchSize;
|
| 74 |
+
|
| 75 |
+
Coord<kRank> extent; /// Extent of source tensor
|
| 76 |
+
FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank
|
| 77 |
+
int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J
|
| 78 |
+
int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K
|
| 79 |
+
int64_t workspace_stride; /// stride (units of bytes) between workspace
|
| 80 |
+
int workspace_count; /// number of workspaces
|
| 81 |
+
|
| 82 |
+
uint64_t inner_count; /// Number of elements in reduced index space
|
| 83 |
+
uint64_t outer_count; /// Number of elements in outer index space
|
| 84 |
+
|
| 85 |
+
ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank
|
| 86 |
+
ElementSource const * source; /// Pointer to source pointer of rank kRank
|
| 87 |
+
ReductionOp reduction_op; /// Reduction operator
|
| 88 |
+
ElementCompute reduction_identity; /// Identity element used by reduction operator
|
| 89 |
+
ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions
|
| 90 |
+
|
| 91 |
+
//
|
| 92 |
+
// Methods
|
| 93 |
+
//
|
| 94 |
+
|
| 95 |
+
/// Ctor
|
| 96 |
+
CUTLASS_HOST_DEVICE
|
| 97 |
+
TensorReductionAffineContiguousParams() {
|
| 98 |
+
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
/// Ctor
|
| 102 |
+
TensorReductionAffineContiguousParams(
|
| 103 |
+
Coord<kRank> extent_, ///< Extent of source tensor
|
| 104 |
+
ElementOutput * dst_ptr_, ///< Output tensor data
|
| 105 |
+
int64_t dst_stride_[], ///< Stride (units of elements)
|
| 106 |
+
ElementSource const * src_ptr_, ///< Source tensor data
|
| 107 |
+
int64_t src_stride_[], ///< Stride (units of elements)
|
| 108 |
+
ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions
|
| 109 |
+
int64_t workspace_stride_, ///< Stride between workspaces
|
| 110 |
+
int workspace_count_, ///< Number of workspaces
|
| 111 |
+
ReductionOp reduction_op_, ///< Reduction operator
|
| 112 |
+
ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator
|
| 113 |
+
):
|
| 114 |
+
extent(extent_),
|
| 115 |
+
inner_count(1),
|
| 116 |
+
outer_count(1),
|
| 117 |
+
destination(dst_ptr_),
|
| 118 |
+
source(src_ptr_),
|
| 119 |
+
device_workspace(device_workspace_),
|
| 120 |
+
workspace_stride(workspace_stride_),
|
| 121 |
+
workspace_count(workspace_count_),
|
| 122 |
+
reduction_op(reduction_op_),
|
| 123 |
+
reduction_identity(reduction_identity_) {
|
| 124 |
+
|
| 125 |
+
// Initialize divisors for fast div-mod
|
| 126 |
+
for (int p = 1; p < kRank; ++p) {
|
| 127 |
+
divmod[p - 1] = FastDivmodU64(uint64_t(extent[p]));
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
int input_size_bits = sizeof_bits<ElementSource>::value;
|
| 131 |
+
int output_size_bits = sizeof_bits<ElementOutput>::value;
|
| 132 |
+
|
| 133 |
+
// Compute strides in units of bytes
|
| 134 |
+
for (int p = 0; p < kReducedRank; ++p) {
|
| 135 |
+
dst_stride[p] = dst_stride_[p] * output_size_bits / 8;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
for (int p = 0; p < kRank - 1; ++p) {
|
| 139 |
+
src_stride[p] = src_stride_[p] * input_size_bits / 8;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
// Compute number of elements in strided ranks
|
| 143 |
+
for (int p = 0; p < kReducedRank; ++p) {
|
| 144 |
+
outer_count *= uint64_t(extent[p]);
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
for (int p = 0; p < kInnerRank; ++p) {
|
| 148 |
+
inner_count *= uint64_t(extent[kRank - 1 - p]);
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
};
|
| 152 |
+
|
| 153 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 154 |
+
|
| 155 |
+
/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous
|
| 156 |
+
/// rank. This leads to favorable vectorized memory accesses over the contiguous rank.
|
| 157 |
+
template <
|
| 158 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 159 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 160 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 161 |
+
typename ElementSource, ///< Data type of source tensor
|
| 162 |
+
typename ReductionOp, ///< Reduction operator
|
| 163 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 164 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 165 |
+
int Threads = 256, ///< Number of participating threads
|
| 166 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 167 |
+
>
|
| 168 |
+
class TensorReductionAffineContiguous {
|
| 169 |
+
public:
|
| 170 |
+
|
| 171 |
+
static int const kRank = Rank;
|
| 172 |
+
static int const kReducedRank = ReducedRank;
|
| 173 |
+
static int const kVectorLength = VectorLength;
|
| 174 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 175 |
+
static int const kThreads = Threads;
|
| 176 |
+
static int const kBatchSize = BatchSize;
|
| 177 |
+
using ComputeFragment = Array<ElementCompute, VectorLength>;
|
| 178 |
+
using SourceFragment = AlignedArray<ElementSource, VectorLength>;
|
| 179 |
+
using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
|
| 180 |
+
|
| 181 |
+
/// Shared memory allocation used for reduction within the CTA
|
| 182 |
+
struct SharedStorage {
|
| 183 |
+
Array<ElementCompute, kThreads * kVectorLength> workspace;
|
| 184 |
+
};
|
| 185 |
+
|
| 186 |
+
/// Parameters structure
|
| 187 |
+
using Params = TensorReductionAffineContiguousParams<
|
| 188 |
+
Rank,
|
| 189 |
+
ReducedRank,
|
| 190 |
+
ElementOutput,
|
| 191 |
+
ElementSource,
|
| 192 |
+
ReductionOp,
|
| 193 |
+
VectorLength,
|
| 194 |
+
ElementCompute,
|
| 195 |
+
Threads,
|
| 196 |
+
BatchSize
|
| 197 |
+
>;
|
| 198 |
+
|
| 199 |
+
private:
|
| 200 |
+
|
| 201 |
+
/// Computes the coordinate and offset of a given linear index
|
| 202 |
+
CUTLASS_DEVICE
|
| 203 |
+
void compute_inner_coord_and_offset_(
|
| 204 |
+
Params const ¶ms,
|
| 205 |
+
Coord<kInnerRank> & coord,
|
| 206 |
+
int64_t &src_offset,
|
| 207 |
+
uint64_t linear_idx) const {
|
| 208 |
+
|
| 209 |
+
// Decompose into a coordinate of rank <kInnerRank>
|
| 210 |
+
coord = CoordinateDecomposition<kInnerRank>(linear_idx, ¶ms.divmod[kRank - kInnerRank]);
|
| 211 |
+
|
| 212 |
+
// Compute an offset using the souce stride
|
| 213 |
+
src_offset = 0;
|
| 214 |
+
CUTLASS_PRAGMA_UNROLL
|
| 215 |
+
for (int i = 0; i < kInnerRank - 1; ++i) {
|
| 216 |
+
src_offset += coord[i] * params.src_stride[kReducedRank + i];
|
| 217 |
+
}
|
| 218 |
+
src_offset += coord[kInnerRank - 1] * sizeof_bits<ElementSource>::value / 8;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
/// Computes the coordinate and offset of a given linear index
|
| 222 |
+
CUTLASS_DEVICE
|
| 223 |
+
void compute_outer_coord_and_offset_(
|
| 224 |
+
Params const ¶ms,
|
| 225 |
+
Coord<kReducedRank> & coord,
|
| 226 |
+
int64_t &dst_offset,
|
| 227 |
+
int64_t &src_offset,
|
| 228 |
+
uint64_t linear_idx) const {
|
| 229 |
+
|
| 230 |
+
// Decompose into coordinate of rank <kReducedRank>
|
| 231 |
+
coord = CoordinateDecomposition<kReducedRank>(linear_idx, params.divmod);
|
| 232 |
+
|
| 233 |
+
// Compute offsets using destination and source strides
|
| 234 |
+
dst_offset = 0;
|
| 235 |
+
src_offset = 0;
|
| 236 |
+
|
| 237 |
+
CUTLASS_PRAGMA_UNROLL
|
| 238 |
+
for (int i = 0; i < kReducedRank; ++i) {
|
| 239 |
+
dst_offset += params.dst_stride[i] * coord[i];
|
| 240 |
+
src_offset += params.src_stride[i] * coord[i];
|
| 241 |
+
}
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/// Reduces over the reduction indices yielding a single element
|
| 245 |
+
CUTLASS_DEVICE
|
| 246 |
+
ElementCompute reduce_indices_(
|
| 247 |
+
Params const ¶ms,
|
| 248 |
+
ElementCompute *threadblock_workspace,
|
| 249 |
+
char const *src_byte_ptr,
|
| 250 |
+
int coord_c) {
|
| 251 |
+
|
| 252 |
+
NumericArrayConverter<ElementCompute, ElementSource, VectorLength> convert_source;
|
| 253 |
+
ReductionOp reduction_op(params.reduction_op);
|
| 254 |
+
|
| 255 |
+
//
|
| 256 |
+
// Early exit or initialize to identity element
|
| 257 |
+
//
|
| 258 |
+
if (!params.inner_count) {
|
| 259 |
+
return params.reduction_identity;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
ComputeFragment accumulator;
|
| 263 |
+
|
| 264 |
+
CUTLASS_PRAGMA_UNROLL
|
| 265 |
+
for (int i = 0; i < int(accumulator.size()); ++i) {
|
| 266 |
+
accumulator[i] = params.reduction_identity;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
// Compute the coordinate of the first access
|
| 270 |
+
int64_t src_byte_offset = 0;
|
| 271 |
+
Coord<kInnerRank> coord;
|
| 272 |
+
|
| 273 |
+
uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength;
|
| 274 |
+
compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
|
| 275 |
+
|
| 276 |
+
// Load the first vector
|
| 277 |
+
SourceFragment source_fragment[kBatchSize];
|
| 278 |
+
|
| 279 |
+
bool not_done = true;
|
| 280 |
+
|
| 281 |
+
// Iterate over vectors in a linearized reduction index space
|
| 282 |
+
while (not_done) {
|
| 283 |
+
|
| 284 |
+
bool guards[kBatchSize];
|
| 285 |
+
|
| 286 |
+
// Issue a batch of loads
|
| 287 |
+
CUTLASS_PRAGMA_UNROLL
|
| 288 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 289 |
+
|
| 290 |
+
if (linear_idx < params.inner_count) {
|
| 291 |
+
source_fragment[b] = *reinterpret_cast<SourceFragment const *>(src_byte_ptr + src_byte_offset);
|
| 292 |
+
guards[b] = true;
|
| 293 |
+
}
|
| 294 |
+
else {
|
| 295 |
+
guards[b] = false;
|
| 296 |
+
not_done = false;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength;
|
| 300 |
+
compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
// Perform a batch of reduction operations
|
| 304 |
+
CUTLASS_PRAGMA_UNROLL
|
| 305 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 306 |
+
if (guards[b]) {
|
| 307 |
+
auto cvt = convert_source(source_fragment[b]);
|
| 308 |
+
|
| 309 |
+
accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
|
| 310 |
+
reduction_op,
|
| 311 |
+
accumulator,
|
| 312 |
+
cvt);
|
| 313 |
+
}
|
| 314 |
+
}
|
| 315 |
+
};
|
| 316 |
+
|
| 317 |
+
//
|
| 318 |
+
// Reduction of vectors to scalar
|
| 319 |
+
//
|
| 320 |
+
|
| 321 |
+
ElementCompute reduced_accumulator = accumulator[0];
|
| 322 |
+
|
| 323 |
+
CUTLASS_PRAGMA_UNROLL
|
| 324 |
+
for (int i = 1; i < kVectorLength; ++i) {
|
| 325 |
+
reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]);
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
//
|
| 329 |
+
// Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0}
|
| 330 |
+
//
|
| 331 |
+
// This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column
|
| 332 |
+
//
|
| 333 |
+
|
| 334 |
+
int thread_count = blockDim.x * blockDim.z;
|
| 335 |
+
int thread_j = threadIdx.x + blockDim.x * threadIdx.z;
|
| 336 |
+
int thread_i = threadIdx.y;
|
| 337 |
+
|
| 338 |
+
ElementCompute *frag_ptr = reinterpret_cast<ElementCompute *>(threadblock_workspace) + thread_i * thread_count;
|
| 339 |
+
|
| 340 |
+
frag_ptr[thread_j] = reduced_accumulator;
|
| 341 |
+
|
| 342 |
+
//
|
| 343 |
+
// Reduce
|
| 344 |
+
//
|
| 345 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 346 |
+
while (thread_count > 1) {
|
| 347 |
+
thread_count /= 2;
|
| 348 |
+
|
| 349 |
+
__syncthreads();
|
| 350 |
+
|
| 351 |
+
if (thread_j < thread_count) {
|
| 352 |
+
ElementCompute other = frag_ptr[thread_j + thread_count];
|
| 353 |
+
|
| 354 |
+
reduced_accumulator = reduction_op(reduced_accumulator, other);
|
| 355 |
+
|
| 356 |
+
frag_ptr[thread_j] = reduced_accumulator;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
__syncthreads();
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
return reduced_accumulator;
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
public:
|
| 367 |
+
|
| 368 |
+
/// Perform a reduction
|
| 369 |
+
CUTLASS_DEVICE
|
| 370 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 371 |
+
|
| 372 |
+
int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
|
| 373 |
+
|
| 374 |
+
char const * src_byte_ptr = reinterpret_cast<char const *>(params.source);
|
| 375 |
+
char * dst_byte_ptr = nullptr;
|
| 376 |
+
|
| 377 |
+
// If performing a reduction across CTAs, redirect output to device workspace
|
| 378 |
+
if (gridDim.z == 1) {
|
| 379 |
+
dst_byte_ptr = reinterpret_cast<char *>(params.destination);
|
| 380 |
+
}
|
| 381 |
+
else {
|
| 382 |
+
dst_byte_ptr = reinterpret_cast<char *>(params.device_workspace);
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
|
| 386 |
+
|
| 387 |
+
// Use modulo division to compute location
|
| 388 |
+
Coord<kReducedRank> outer_coord;
|
| 389 |
+
int64_t dst_byte_offset;
|
| 390 |
+
int64_t src_byte_offset;
|
| 391 |
+
|
| 392 |
+
compute_outer_coord_and_offset_(
|
| 393 |
+
params,
|
| 394 |
+
outer_coord,
|
| 395 |
+
dst_byte_offset,
|
| 396 |
+
src_byte_offset,
|
| 397 |
+
idx_linear);
|
| 398 |
+
|
| 399 |
+
if (gridDim.z == 1) {
|
| 400 |
+
|
| 401 |
+
/// Complete the reduction with no workspace
|
| 402 |
+
while (idx_linear < params.outer_count) {
|
| 403 |
+
|
| 404 |
+
ElementCompute result = reduce_indices_(
|
| 405 |
+
params,
|
| 406 |
+
shared_storage.workspace.data(),
|
| 407 |
+
src_byte_ptr + src_byte_offset,
|
| 408 |
+
coord_c);
|
| 409 |
+
|
| 410 |
+
// Store the result after possible final reduction within the CTA
|
| 411 |
+
if (threadIdx.z == 0 && threadIdx.x == 0) {
|
| 412 |
+
|
| 413 |
+
// Convert to output type and store
|
| 414 |
+
NumericConverter<ElementOutput, ElementCompute> convert_output;
|
| 415 |
+
ElementOutput cvt = convert_output(result);
|
| 416 |
+
|
| 417 |
+
*reinterpret_cast<ElementOutput *>(dst_byte_ptr + dst_byte_offset) = cvt;
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
__syncthreads();
|
| 421 |
+
|
| 422 |
+
// Update indices and pointers
|
| 423 |
+
idx_linear += gridDim.y * blockDim.y;
|
| 424 |
+
|
| 425 |
+
compute_outer_coord_and_offset_(
|
| 426 |
+
params,
|
| 427 |
+
outer_coord,
|
| 428 |
+
dst_byte_offset,
|
| 429 |
+
src_byte_offset,
|
| 430 |
+
idx_linear);
|
| 431 |
+
|
| 432 |
+
} // while
|
| 433 |
+
}
|
| 434 |
+
else {
|
| 435 |
+
|
| 436 |
+
/// Complete the reduction with workspace
|
| 437 |
+
while (idx_linear < params.outer_count) {
|
| 438 |
+
|
| 439 |
+
ElementCompute result = reduce_indices_(
|
| 440 |
+
params,
|
| 441 |
+
shared_storage.workspace.data(),
|
| 442 |
+
src_byte_ptr + src_byte_offset,
|
| 443 |
+
coord_c);
|
| 444 |
+
|
| 445 |
+
int64_t byte_offset =
|
| 446 |
+
blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits<ElementCompute>::value / 8;
|
| 447 |
+
|
| 448 |
+
// Store the result for final reduction
|
| 449 |
+
if (threadIdx.z == 0 && threadIdx.x == 0) {
|
| 450 |
+
*reinterpret_cast<ElementCompute *>(dst_byte_ptr + byte_offset) = result;
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
__syncthreads();
|
| 454 |
+
|
| 455 |
+
// Update indices and pointers
|
| 456 |
+
idx_linear += gridDim.y * blockDim.y;
|
| 457 |
+
|
| 458 |
+
compute_outer_coord_and_offset_(
|
| 459 |
+
params,
|
| 460 |
+
outer_coord,
|
| 461 |
+
dst_byte_offset,
|
| 462 |
+
src_byte_offset,
|
| 463 |
+
idx_linear);
|
| 464 |
+
} // while
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
};
|
| 468 |
+
|
| 469 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 470 |
+
|
| 471 |
+
/// Kernel to perform final reduction
|
| 472 |
+
template <
|
| 473 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 474 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 475 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 476 |
+
typename ElementSource, ///< Data type of source tensor
|
| 477 |
+
typename ReductionOp, ///< Reduction operator
|
| 478 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 479 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 480 |
+
int Threads = 256, ///< Number of participating threads
|
| 481 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 482 |
+
>
|
| 483 |
+
class TensorReductionAffineContiguousFinal {
|
| 484 |
+
public:
|
| 485 |
+
|
| 486 |
+
static int const kRank = Rank;
|
| 487 |
+
static int const kReducedRank = ReducedRank;
|
| 488 |
+
static int const kVectorLength = VectorLength;
|
| 489 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 490 |
+
static int const kThreads = Threads;
|
| 491 |
+
static int const kBatchSize = BatchSize;
|
| 492 |
+
|
| 493 |
+
/// Shared memory
|
| 494 |
+
struct SharedStorage { };
|
| 495 |
+
|
| 496 |
+
/// Parameters structure
|
| 497 |
+
using Params = TensorReductionAffineContiguousParams<
|
| 498 |
+
Rank,
|
| 499 |
+
ReducedRank,
|
| 500 |
+
ElementOutput,
|
| 501 |
+
ElementSource,
|
| 502 |
+
ReductionOp,
|
| 503 |
+
VectorLength,
|
| 504 |
+
ElementCompute,
|
| 505 |
+
Threads,
|
| 506 |
+
BatchSize
|
| 507 |
+
>;
|
| 508 |
+
|
| 509 |
+
private:
|
| 510 |
+
|
| 511 |
+
/// Computes the coordinate and offset of a given linear index
|
| 512 |
+
CUTLASS_DEVICE
|
| 513 |
+
void compute_outer_coord_and_offset_(
|
| 514 |
+
Params const ¶ms,
|
| 515 |
+
Coord<kReducedRank> & coord,
|
| 516 |
+
int64_t &dst_offset,
|
| 517 |
+
uint64_t linear_idx) const {
|
| 518 |
+
|
| 519 |
+
// Decompose into coordinate of rank <kReducedRank>
|
| 520 |
+
coord = CoordinateDecomposition<kReducedRank>(linear_idx, params.divmod);
|
| 521 |
+
|
| 522 |
+
// Compute offsets using destination and source strides
|
| 523 |
+
dst_offset = 0;
|
| 524 |
+
|
| 525 |
+
CUTLASS_PRAGMA_UNROLL
|
| 526 |
+
for (int i = 0; i < kReducedRank; ++i) {
|
| 527 |
+
dst_offset += params.dst_stride[i] * coord[i];
|
| 528 |
+
}
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
+
/// Reduces over the reduction indices
|
| 532 |
+
CUTLASS_DEVICE
|
| 533 |
+
ElementCompute reduce_indices_(
|
| 534 |
+
Params const ¶ms,
|
| 535 |
+
ElementCompute const *device_workspace) {
|
| 536 |
+
|
| 537 |
+
ReductionOp reduction_op(params.reduction_op);
|
| 538 |
+
char const *src_byte_ptr = reinterpret_cast<char const *>(device_workspace);
|
| 539 |
+
|
| 540 |
+
// Accumulated output
|
| 541 |
+
ElementCompute accumulator = params.reduction_identity;
|
| 542 |
+
|
| 543 |
+
for (int iter = 0; iter < params.workspace_count; ++iter) {
|
| 544 |
+
ElementCompute workspace_item = *reinterpret_cast<ElementCompute const *>(src_byte_ptr);
|
| 545 |
+
|
| 546 |
+
accumulator = reduction_op(accumulator, workspace_item);
|
| 547 |
+
|
| 548 |
+
src_byte_ptr += params.workspace_stride;
|
| 549 |
+
}
|
| 550 |
+
|
| 551 |
+
return accumulator;
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
public:
|
| 555 |
+
|
| 556 |
+
//
|
| 557 |
+
// Methods
|
| 558 |
+
//
|
| 559 |
+
|
| 560 |
+
/// Perform a reduction
|
| 561 |
+
CUTLASS_DEVICE
|
| 562 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 563 |
+
|
| 564 |
+
uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x;
|
| 565 |
+
|
| 566 |
+
char * dst_byte_ptr = reinterpret_cast<char *>(params.destination);
|
| 567 |
+
|
| 568 |
+
// Use modulo division to compute location
|
| 569 |
+
Coord<kReducedRank> outer_coord;
|
| 570 |
+
int64_t dst_byte_offset;
|
| 571 |
+
|
| 572 |
+
compute_outer_coord_and_offset_(
|
| 573 |
+
params,
|
| 574 |
+
outer_coord,
|
| 575 |
+
dst_byte_offset,
|
| 576 |
+
idx_linear);
|
| 577 |
+
|
| 578 |
+
/// Complete the reduction
|
| 579 |
+
while (idx_linear < params.outer_count) {
|
| 580 |
+
|
| 581 |
+
ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear);
|
| 582 |
+
|
| 583 |
+
// Convert to output type and store
|
| 584 |
+
NumericConverter<ElementOutput, ElementCompute> convert_output;
|
| 585 |
+
|
| 586 |
+
*reinterpret_cast<ElementOutput *>(dst_byte_ptr + dst_byte_offset) = convert_output(result);
|
| 587 |
+
|
| 588 |
+
// Update indices and pointers
|
| 589 |
+
idx_linear += gridDim.x * blockDim.x;
|
| 590 |
+
|
| 591 |
+
compute_outer_coord_and_offset_(
|
| 592 |
+
params,
|
| 593 |
+
outer_coord,
|
| 594 |
+
dst_byte_offset,
|
| 595 |
+
idx_linear);
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
};
|
| 599 |
+
|
| 600 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 601 |
+
|
| 602 |
+
} // namespace kernel
|
| 603 |
+
} // namespace reduction
|
| 604 |
+
} // namespace cutlass
|
| 605 |
+
|
| 606 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h
ADDED
|
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over one or more ranks of an affine tensor
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/fast_math.h"
|
| 40 |
+
#include "cutlass/numeric_types.h"
|
| 41 |
+
#include "cutlass/numeric_conversion.h"
|
| 42 |
+
#include "cutlass/device_kernel.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace reduction {
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
namespace kernel {
|
| 54 |
+
|
| 55 |
+
/// Parameters structure
|
| 56 |
+
template <
|
| 57 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 58 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 59 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 60 |
+
typename ElementSource, ///< Data type of source tensor
|
| 61 |
+
typename ReductionOp, ///< Reduction operator
|
| 62 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 63 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 64 |
+
int Threads = 256, ///< Number of participating threads
|
| 65 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 66 |
+
>
|
| 67 |
+
struct TensorReductionAffineStridedParams {
|
| 68 |
+
|
| 69 |
+
static int const kRank = Rank;
|
| 70 |
+
static int const kReducedRank = ReducedRank;
|
| 71 |
+
static int const kVectorLength = VectorLength;
|
| 72 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 73 |
+
static int const kThreads = Threads;
|
| 74 |
+
static int const kBatchSize = BatchSize;
|
| 75 |
+
|
| 76 |
+
Coord<kRank> extent; /// Extent of source tensor
|
| 77 |
+
FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank
|
| 78 |
+
int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J
|
| 79 |
+
int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K
|
| 80 |
+
int64_t workspace_stride; /// stride (units of bytes) between workspace
|
| 81 |
+
int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace
|
| 82 |
+
int workspace_count; /// number of workspaces
|
| 83 |
+
|
| 84 |
+
uint64_t inner_count; /// Number of elements in reduced index space
|
| 85 |
+
uint64_t outer_count; /// Number of elements in outer index space
|
| 86 |
+
|
| 87 |
+
ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank
|
| 88 |
+
ElementSource const * source; /// Pointer to source pointer of rank kRank
|
| 89 |
+
ReductionOp reduction_op; /// Reduction operator
|
| 90 |
+
ElementCompute reduction_identity; /// Identity element for reduction operator
|
| 91 |
+
ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions
|
| 92 |
+
|
| 93 |
+
//
|
| 94 |
+
// Methods
|
| 95 |
+
//
|
| 96 |
+
|
| 97 |
+
/// Ctor
|
| 98 |
+
CUTLASS_HOST_DEVICE
|
| 99 |
+
TensorReductionAffineStridedParams() {
|
| 100 |
+
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
/// Ctor
|
| 104 |
+
TensorReductionAffineStridedParams(
|
| 105 |
+
Coord<kRank> extent_, ///< Extent of source tensor
|
| 106 |
+
ElementOutput * dst_ptr_, ///< Output tensor data
|
| 107 |
+
int64_t dst_stride_[], ///< Stride (units of elements)
|
| 108 |
+
ElementSource const * src_ptr_, ///< Source tensor data
|
| 109 |
+
int64_t src_stride_[], ///< Stride (units of elements)
|
| 110 |
+
ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions
|
| 111 |
+
int64_t workspace_stride_, ///< Stride between workspaces
|
| 112 |
+
int workspace_count_, ///< Number of workspaces
|
| 113 |
+
ReductionOp reduction_op_, ///< Reduction operator
|
| 114 |
+
ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator
|
| 115 |
+
):
|
| 116 |
+
extent(extent_),
|
| 117 |
+
inner_count(1),
|
| 118 |
+
outer_count(1),
|
| 119 |
+
destination(dst_ptr_),
|
| 120 |
+
source(src_ptr_),
|
| 121 |
+
device_workspace(device_workspace_),
|
| 122 |
+
workspace_outer_stride(0),
|
| 123 |
+
workspace_stride(workspace_stride_),
|
| 124 |
+
workspace_count(workspace_count_),
|
| 125 |
+
reduction_op(reduction_op_),
|
| 126 |
+
reduction_identity(reduction_identity_) {
|
| 127 |
+
|
| 128 |
+
// Initialize divisors for fast div-mod
|
| 129 |
+
for (int p = 1; p < kRank; ++p) {
|
| 130 |
+
divmod[p - 1] = FastDivmodU64(uint64_t(extent[p]));
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
int input_size_bits = sizeof_bits<ElementSource>::value;
|
| 134 |
+
int output_size_bits = sizeof_bits<ElementOutput>::value;
|
| 135 |
+
|
| 136 |
+
workspace_outer_stride = workspace_stride * workspace_count;
|
| 137 |
+
|
| 138 |
+
// Compute strides in units of bytes
|
| 139 |
+
for (int p = 0; p < kReducedRank - 1; ++p) {
|
| 140 |
+
dst_stride[p] = dst_stride_[p] * output_size_bits / 8;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
for (int p = 0; p < kRank - 1; ++p) {
|
| 144 |
+
src_stride[p] = src_stride_[p] * input_size_bits / 8;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
// Compute number of elements in strided ranks
|
| 148 |
+
for (int p = 0; p < kReducedRank - 1; ++p) {
|
| 149 |
+
outer_count *= uint64_t(extent[p]);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
for (int p = 0; p < kInnerRank; ++p) {
|
| 153 |
+
inner_count *= uint64_t(extent[kReducedRank + p - 1]);
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
};
|
| 157 |
+
|
| 158 |
+
/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous
|
| 159 |
+
/// rank. This leads to favorable vectorized memory accesses over the contiguous rank.
|
| 160 |
+
template <
|
| 161 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 162 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 163 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 164 |
+
typename ElementSource, ///< Data type of source tensor
|
| 165 |
+
typename ReductionOp, ///< Reduction operator
|
| 166 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 167 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 168 |
+
int Threads = 256, ///< Number of participating threads
|
| 169 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 170 |
+
>
|
| 171 |
+
class TensorReductionAffineStrided {
|
| 172 |
+
public:
|
| 173 |
+
|
| 174 |
+
static int const kRank = Rank;
|
| 175 |
+
static int const kReducedRank = ReducedRank;
|
| 176 |
+
static int const kVectorLength = VectorLength;
|
| 177 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 178 |
+
static int const kThreads = Threads;
|
| 179 |
+
static int const kBatchSize = BatchSize;
|
| 180 |
+
using ComputeFragment = Array<ElementCompute, VectorLength>;
|
| 181 |
+
using SourceFragment = AlignedArray<ElementSource, VectorLength>;
|
| 182 |
+
using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
|
| 183 |
+
|
| 184 |
+
/// Shared memory allocation used for reduction within the CTA
|
| 185 |
+
struct SharedStorage {
|
| 186 |
+
Array<ElementCompute, kThreads * kVectorLength> workspace;
|
| 187 |
+
};
|
| 188 |
+
|
| 189 |
+
/// Parameters structure
|
| 190 |
+
using Params = TensorReductionAffineStridedParams<
|
| 191 |
+
Rank,
|
| 192 |
+
ReducedRank,
|
| 193 |
+
ElementOutput,
|
| 194 |
+
ElementSource,
|
| 195 |
+
ReductionOp,
|
| 196 |
+
VectorLength,
|
| 197 |
+
ElementCompute,
|
| 198 |
+
Threads,
|
| 199 |
+
BatchSize
|
| 200 |
+
>;
|
| 201 |
+
|
| 202 |
+
private:
|
| 203 |
+
|
| 204 |
+
/// Computes the coordinate and offset of a given linear index
|
| 205 |
+
CUTLASS_DEVICE
|
| 206 |
+
void compute_inner_coord_and_offset_(
|
| 207 |
+
Params const ¶ms,
|
| 208 |
+
Coord<kInnerRank> & coord,
|
| 209 |
+
int64_t &src_offset,
|
| 210 |
+
uint64_t linear_idx) const {
|
| 211 |
+
|
| 212 |
+
// Decompose into coordinate
|
| 213 |
+
coord = CoordinateDecomposition<kInnerRank>(linear_idx, ¶ms.divmod[kReducedRank - 1]);
|
| 214 |
+
|
| 215 |
+
// Compute linear offset
|
| 216 |
+
src_offset = 0;
|
| 217 |
+
|
| 218 |
+
CUTLASS_PRAGMA_UNROLL
|
| 219 |
+
for (int i = 0; i < kInnerRank; ++i) {
|
| 220 |
+
src_offset += params.src_stride[kReducedRank + i - 1] * coord[i];
|
| 221 |
+
}
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
/// Computes the coordinate and offset of a given linear index
|
| 225 |
+
CUTLASS_DEVICE
|
| 226 |
+
void compute_outer_coord_and_offset_(
|
| 227 |
+
Params const ¶ms,
|
| 228 |
+
Coord<kReducedRank - 1> & coord,
|
| 229 |
+
int64_t &dst_offset,
|
| 230 |
+
int64_t &src_offset,
|
| 231 |
+
uint64_t linear_idx) const {
|
| 232 |
+
|
| 233 |
+
// Decompose linear coordinate
|
| 234 |
+
coord = CoordinateDecomposition<kReducedRank - 1>(linear_idx, params.divmod);
|
| 235 |
+
|
| 236 |
+
// Compute offset into tensors
|
| 237 |
+
dst_offset = 0;
|
| 238 |
+
src_offset = 0;
|
| 239 |
+
|
| 240 |
+
CUTLASS_PRAGMA_UNROLL
|
| 241 |
+
for (int i = 0; i < kReducedRank - 1; ++i) {
|
| 242 |
+
dst_offset += params.dst_stride[i] * coord[i];
|
| 243 |
+
src_offset += params.src_stride[i] * coord[i];
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
/// Reduces over the reduction indices
|
| 248 |
+
CUTLASS_DEVICE
|
| 249 |
+
ComputeFragment reduce_indices_(
|
| 250 |
+
Params const ¶ms,
|
| 251 |
+
ElementCompute *threadblock_workspace,
|
| 252 |
+
char const *src_byte_ptr) {
|
| 253 |
+
|
| 254 |
+
NumericArrayConverter<ElementCompute, ElementSource, VectorLength> convert_source;
|
| 255 |
+
ReductionOp reduction_op(params.reduction_op);
|
| 256 |
+
|
| 257 |
+
// Accumulated output
|
| 258 |
+
ComputeFragment identity_frag;
|
| 259 |
+
|
| 260 |
+
CUTLASS_PRAGMA_UNROLL
|
| 261 |
+
for (int i = 0; i < int(identity_frag.size()); ++i) {
|
| 262 |
+
identity_frag[i] = params.reduction_identity;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
if (!params.inner_count) {
|
| 266 |
+
return identity_frag;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
ComputeFragment accumulator = identity_frag;
|
| 270 |
+
|
| 271 |
+
// Compute the coordinate of the first access
|
| 272 |
+
int64_t src_byte_offset = 0;
|
| 273 |
+
Coord<kInnerRank> coord;
|
| 274 |
+
|
| 275 |
+
uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z;
|
| 276 |
+
compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
|
| 277 |
+
|
| 278 |
+
// Load the first vector
|
| 279 |
+
SourceFragment source_fragment[kBatchSize];
|
| 280 |
+
|
| 281 |
+
bool not_done = true;
|
| 282 |
+
|
| 283 |
+
// Iterate over vectors in a linearized reduction index space
|
| 284 |
+
while (not_done) {
|
| 285 |
+
|
| 286 |
+
bool guards[kBatchSize];
|
| 287 |
+
|
| 288 |
+
// Issue a batch of loads
|
| 289 |
+
CUTLASS_PRAGMA_UNROLL
|
| 290 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 291 |
+
|
| 292 |
+
if (linear_idx < params.inner_count) {
|
| 293 |
+
source_fragment[b] = *reinterpret_cast<SourceFragment const *>(src_byte_ptr + src_byte_offset);
|
| 294 |
+
guards[b] = true;
|
| 295 |
+
}
|
| 296 |
+
else {
|
| 297 |
+
guards[b] = false;
|
| 298 |
+
not_done = false;
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
linear_idx += blockDim.z * gridDim.z;
|
| 302 |
+
compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx);
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
// Perform a batch of reduction operations
|
| 306 |
+
CUTLASS_PRAGMA_UNROLL
|
| 307 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 308 |
+
if (guards[b]) {
|
| 309 |
+
|
| 310 |
+
auto cvt = convert_source(source_fragment[b]);
|
| 311 |
+
|
| 312 |
+
accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
|
| 313 |
+
reduction_op,
|
| 314 |
+
accumulator,
|
| 315 |
+
cvt);
|
| 316 |
+
}
|
| 317 |
+
}
|
| 318 |
+
};
|
| 319 |
+
|
| 320 |
+
// Optional reduction within a CTA
|
| 321 |
+
if (blockDim.z > 1) {
|
| 322 |
+
|
| 323 |
+
// Linearized thread ID
|
| 324 |
+
int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z);
|
| 325 |
+
|
| 326 |
+
// all threads store to workspace
|
| 327 |
+
ComputeFragment *frag_ptr = reinterpret_cast<ComputeFragment *>(threadblock_workspace);
|
| 328 |
+
|
| 329 |
+
frag_ptr[thread_idx] = accumulator;
|
| 330 |
+
|
| 331 |
+
__syncthreads();
|
| 332 |
+
|
| 333 |
+
if (threadIdx.z == 0) {
|
| 334 |
+
// Load all additional block indices
|
| 335 |
+
for (int z = 1; z < blockDim.z; ++z) {
|
| 336 |
+
ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y];
|
| 337 |
+
|
| 338 |
+
accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator(
|
| 339 |
+
reduction_op,
|
| 340 |
+
accumulator,
|
| 341 |
+
frag);
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
__syncthreads();
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
return accumulator;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
public:
|
| 352 |
+
|
| 353 |
+
/// Perform a reduction
|
| 354 |
+
CUTLASS_DEVICE
|
| 355 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 356 |
+
|
| 357 |
+
int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
|
| 358 |
+
|
| 359 |
+
char const * src_byte_ptr = reinterpret_cast<char const *>(params.source + coord_c);
|
| 360 |
+
char * dst_byte_ptr = nullptr;
|
| 361 |
+
|
| 362 |
+
// If performing a reduction across CTAs, redirect output to device workspace
|
| 363 |
+
if (gridDim.z == 1) {
|
| 364 |
+
dst_byte_ptr = reinterpret_cast<char *>(params.destination + coord_c);
|
| 365 |
+
}
|
| 366 |
+
else {
|
| 367 |
+
dst_byte_ptr = reinterpret_cast<char *>(params.device_workspace + coord_c);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
// If the C index is out of bounds, exit
|
| 371 |
+
if (coord_c >= params.extent[kRank - 1]) {
|
| 372 |
+
return;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
|
| 376 |
+
|
| 377 |
+
// Use modulo division to compute location
|
| 378 |
+
Coord<kReducedRank - 1> outer_coord;
|
| 379 |
+
int64_t dst_byte_offset;
|
| 380 |
+
int64_t src_byte_offset;
|
| 381 |
+
|
| 382 |
+
compute_outer_coord_and_offset_(
|
| 383 |
+
params,
|
| 384 |
+
outer_coord,
|
| 385 |
+
dst_byte_offset,
|
| 386 |
+
src_byte_offset,
|
| 387 |
+
idx_linear);
|
| 388 |
+
|
| 389 |
+
if (gridDim.z == 1) {
|
| 390 |
+
|
| 391 |
+
/// Complete the reduction with no workspace
|
| 392 |
+
while (idx_linear < params.outer_count) {
|
| 393 |
+
|
| 394 |
+
ComputeFragment result;
|
| 395 |
+
|
| 396 |
+
result = reduce_indices_(
|
| 397 |
+
params,
|
| 398 |
+
shared_storage.workspace.data(),
|
| 399 |
+
src_byte_ptr + src_byte_offset);
|
| 400 |
+
|
| 401 |
+
// Store the result after possible final reduction within the CTA
|
| 402 |
+
if (threadIdx.z == 0) {
|
| 403 |
+
|
| 404 |
+
// Convert to output type and store
|
| 405 |
+
NumericArrayConverter<ElementOutput, ElementCompute, VectorLength> convert_output;
|
| 406 |
+
auto cvt = convert_output(result);
|
| 407 |
+
|
| 408 |
+
*reinterpret_cast<OutputFragment *>(dst_byte_ptr + dst_byte_offset) =
|
| 409 |
+
reinterpret_cast<OutputFragment const &>(cvt);
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
// Update indices and pointers
|
| 413 |
+
idx_linear += gridDim.y * blockDim.y;
|
| 414 |
+
|
| 415 |
+
compute_outer_coord_and_offset_(
|
| 416 |
+
params,
|
| 417 |
+
outer_coord,
|
| 418 |
+
dst_byte_offset,
|
| 419 |
+
src_byte_offset,
|
| 420 |
+
idx_linear);
|
| 421 |
+
|
| 422 |
+
} // while
|
| 423 |
+
}
|
| 424 |
+
else {
|
| 425 |
+
|
| 426 |
+
/// Complete the reduction with a device workspace
|
| 427 |
+
while (idx_linear < params.outer_count) {
|
| 428 |
+
|
| 429 |
+
ComputeFragment result;
|
| 430 |
+
|
| 431 |
+
result = reduce_indices_(
|
| 432 |
+
params,
|
| 433 |
+
shared_storage.workspace.data(),
|
| 434 |
+
src_byte_ptr + src_byte_offset);
|
| 435 |
+
|
| 436 |
+
// Store the result after possible final reduction within the CTA
|
| 437 |
+
if (threadIdx.z == 0) {
|
| 438 |
+
|
| 439 |
+
int64_t byte_offset =
|
| 440 |
+
blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride;
|
| 441 |
+
|
| 442 |
+
// No conversion - store in compute type
|
| 443 |
+
*reinterpret_cast<ComputeFragment *>(dst_byte_ptr + byte_offset) =
|
| 444 |
+
reinterpret_cast<ComputeFragment const &>(result);
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
// Update indices and pointers
|
| 448 |
+
idx_linear += gridDim.y * blockDim.y;
|
| 449 |
+
|
| 450 |
+
compute_outer_coord_and_offset_(
|
| 451 |
+
params,
|
| 452 |
+
outer_coord,
|
| 453 |
+
dst_byte_offset,
|
| 454 |
+
src_byte_offset,
|
| 455 |
+
idx_linear);
|
| 456 |
+
|
| 457 |
+
} // while (outer index)
|
| 458 |
+
} // if ()
|
| 459 |
+
}
|
| 460 |
+
};
|
| 461 |
+
|
| 462 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 463 |
+
|
| 464 |
+
/// Kernel to perform final reduction
|
| 465 |
+
template <
|
| 466 |
+
int Rank, ///< Rank of source tensor (e.g. NDHWC => 5)
|
| 467 |
+
int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2)
|
| 468 |
+
typename ElementOutput, ///< Data type of output tensor
|
| 469 |
+
typename ElementSource, ///< Data type of source tensor
|
| 470 |
+
typename ReductionOp, ///< Reduction operator
|
| 471 |
+
int VectorLength = 1, ///< Vector length for memory
|
| 472 |
+
typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation
|
| 473 |
+
int Threads = 256, ///< Number of participating threads
|
| 474 |
+
int BatchSize = 4 ///< Number of elements to load per batch
|
| 475 |
+
>
|
| 476 |
+
class TensorReductionAffineStridedFinal {
|
| 477 |
+
public:
|
| 478 |
+
|
| 479 |
+
static int const kRank = Rank;
|
| 480 |
+
static int const kReducedRank = ReducedRank;
|
| 481 |
+
static int const kVectorLength = VectorLength;
|
| 482 |
+
static int const kInnerRank = kRank - kReducedRank;
|
| 483 |
+
static int const kThreads = Threads;
|
| 484 |
+
static int const kBatchSize = BatchSize;
|
| 485 |
+
using ComputeFragment = Array<ElementCompute, VectorLength>;
|
| 486 |
+
using SourceFragment = AlignedArray<ElementSource, VectorLength>;
|
| 487 |
+
using OutputFragment = AlignedArray<ElementOutput, VectorLength>;
|
| 488 |
+
|
| 489 |
+
/// Shared memory
|
| 490 |
+
struct SharedStorage { };
|
| 491 |
+
|
| 492 |
+
/// Parameters structure
|
| 493 |
+
using Params = TensorReductionAffineStridedParams<
|
| 494 |
+
Rank,
|
| 495 |
+
ReducedRank,
|
| 496 |
+
ElementOutput,
|
| 497 |
+
ElementSource,
|
| 498 |
+
ReductionOp,
|
| 499 |
+
VectorLength,
|
| 500 |
+
ElementCompute,
|
| 501 |
+
Threads,
|
| 502 |
+
BatchSize
|
| 503 |
+
>;
|
| 504 |
+
|
| 505 |
+
private:
|
| 506 |
+
|
| 507 |
+
/// Computes the coordinate and offset of a given linear index
|
| 508 |
+
CUTLASS_DEVICE
|
| 509 |
+
void compute_outer_coord_and_offset_(
|
| 510 |
+
Params const ¶ms,
|
| 511 |
+
Coord<kReducedRank - 1> & coord,
|
| 512 |
+
int64_t &dst_offset,
|
| 513 |
+
uint64_t linear_idx) const {
|
| 514 |
+
|
| 515 |
+
// Decompose linear index
|
| 516 |
+
coord = CoordinateDecomposition<kReducedRank - 1>(linear_idx, params.divmod);
|
| 517 |
+
|
| 518 |
+
// Compute tensor offset
|
| 519 |
+
dst_offset = 0;
|
| 520 |
+
|
| 521 |
+
CUTLASS_PRAGMA_UNROLL
|
| 522 |
+
for (int i = 0; i < kReducedRank - 1; ++i) {
|
| 523 |
+
dst_offset += params.dst_stride[i] * coord[i];
|
| 524 |
+
}
|
| 525 |
+
}
|
| 526 |
+
|
| 527 |
+
/// Reduces over the reduction indices
|
| 528 |
+
CUTLASS_DEVICE
|
| 529 |
+
ComputeFragment reduce_indices_(
|
| 530 |
+
Params const ¶ms,
|
| 531 |
+
char *src_byte_ptr) {
|
| 532 |
+
|
| 533 |
+
ReductionOp reduction_op(params.reduction_op);
|
| 534 |
+
|
| 535 |
+
// Accumulated output
|
| 536 |
+
ComputeFragment identity_frag;
|
| 537 |
+
|
| 538 |
+
CUTLASS_PRAGMA_UNROLL
|
| 539 |
+
for (int i = 0; i < int(identity_frag.size()); ++i) {
|
| 540 |
+
identity_frag[i] = params.reduction_identity;
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
ComputeFragment accumulator = identity_frag;
|
| 544 |
+
ComputeFragment workspace_fragments[kBatchSize];
|
| 545 |
+
|
| 546 |
+
// Partially unrolled loop
|
| 547 |
+
for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) {
|
| 548 |
+
|
| 549 |
+
// Issue a batch of loads
|
| 550 |
+
CUTLASS_PRAGMA_UNROLL
|
| 551 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 552 |
+
if (idx + b < params.workspace_count) {
|
| 553 |
+
workspace_fragments[b] =
|
| 554 |
+
*reinterpret_cast<ComputeFragment *>(src_byte_ptr);
|
| 555 |
+
}
|
| 556 |
+
else {
|
| 557 |
+
workspace_fragments[b] = identity_frag;
|
| 558 |
+
}
|
| 559 |
+
src_byte_ptr += + params.workspace_stride;
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
// Perform a reduction
|
| 563 |
+
CUTLASS_PRAGMA_UNROLL
|
| 564 |
+
for (int b = 0; b < kBatchSize; ++b) {
|
| 565 |
+
CUTLASS_PRAGMA_UNROLL
|
| 566 |
+
for (int i = 0; i < kVectorLength; ++i) {
|
| 567 |
+
accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]);
|
| 568 |
+
}
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
return accumulator;
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
public:
|
| 576 |
+
|
| 577 |
+
//
|
| 578 |
+
// Methods
|
| 579 |
+
//
|
| 580 |
+
|
| 581 |
+
/// Perform a reduction
|
| 582 |
+
CUTLASS_DEVICE
|
| 583 |
+
void operator()(Params const ¶ms, SharedStorage &shared_storage) {
|
| 584 |
+
|
| 585 |
+
int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength;
|
| 586 |
+
|
| 587 |
+
char * src_byte_ptr = reinterpret_cast<char *>(params.device_workspace + coord_c);
|
| 588 |
+
char * dst_byte_ptr = reinterpret_cast<char *>(params.destination + coord_c);
|
| 589 |
+
|
| 590 |
+
// If the C index is out of bounds, exit
|
| 591 |
+
if (coord_c >= params.extent[kRank - 1]) {
|
| 592 |
+
return;
|
| 593 |
+
}
|
| 594 |
+
|
| 595 |
+
int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y;
|
| 596 |
+
|
| 597 |
+
// Use modulo division to compute location
|
| 598 |
+
Coord<kReducedRank - 1> outer_coord;
|
| 599 |
+
int64_t dst_byte_offset;
|
| 600 |
+
|
| 601 |
+
compute_outer_coord_and_offset_(
|
| 602 |
+
params,
|
| 603 |
+
outer_coord,
|
| 604 |
+
dst_byte_offset,
|
| 605 |
+
idx_linear);
|
| 606 |
+
|
| 607 |
+
/// Complete the reduction
|
| 608 |
+
while (idx_linear < params.outer_count) {
|
| 609 |
+
|
| 610 |
+
int64_t src_byte_offset = idx_linear * params.workspace_outer_stride;
|
| 611 |
+
|
| 612 |
+
ComputeFragment result = reduce_indices_(
|
| 613 |
+
params,
|
| 614 |
+
src_byte_ptr + src_byte_offset);
|
| 615 |
+
|
| 616 |
+
// Convert to output type and store
|
| 617 |
+
NumericArrayConverter<ElementOutput, ElementCompute, VectorLength> convert_output;
|
| 618 |
+
auto cvt = convert_output(result);
|
| 619 |
+
|
| 620 |
+
*reinterpret_cast<OutputFragment *>(dst_byte_ptr + dst_byte_offset) =
|
| 621 |
+
reinterpret_cast<OutputFragment const &>(cvt);
|
| 622 |
+
|
| 623 |
+
// Update indices and pointers
|
| 624 |
+
idx_linear += gridDim.y * blockDim.y;
|
| 625 |
+
|
| 626 |
+
compute_outer_coord_and_offset_(
|
| 627 |
+
params,
|
| 628 |
+
outer_coord,
|
| 629 |
+
dst_byte_offset,
|
| 630 |
+
idx_linear);
|
| 631 |
+
}
|
| 632 |
+
}
|
| 633 |
+
};
|
| 634 |
+
|
| 635 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 636 |
+
|
| 637 |
+
} // namespace kernel
|
| 638 |
+
} // namespace reduction
|
| 639 |
+
} // namespace cutlass
|
| 640 |
+
|
| 641 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines basic thread level reduction with specializations for Array<T, N>.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/numeric_types.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/half.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
namespace reduction {
|
| 45 |
+
namespace thread {
|
| 46 |
+
|
| 47 |
+
/// Structure to compute the thread level reduction
|
| 48 |
+
template <typename Op, typename T>
|
| 49 |
+
struct Reduce;
|
| 50 |
+
|
| 51 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Partial Specialization of Reduce for "plus" (a functional operator)
|
| 54 |
+
template <typename T>
|
| 55 |
+
struct Reduce< plus<T>, T > {
|
| 56 |
+
|
| 57 |
+
CUTLASS_HOST_DEVICE
|
| 58 |
+
T operator()(T lhs, T const &rhs) const {
|
| 59 |
+
plus<T> _op;
|
| 60 |
+
return _op(lhs, rhs);
|
| 61 |
+
}
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
/// Partial specialization of Reduce for Array<T, N>
|
| 67 |
+
template <typename T, int N>
|
| 68 |
+
struct Reduce < plus<T>, Array<T, N>> {
|
| 69 |
+
|
| 70 |
+
CUTLASS_HOST_DEVICE
|
| 71 |
+
Array<T, 1> operator()(Array<T, N> const &in) const {
|
| 72 |
+
|
| 73 |
+
Array<T, 1> result;
|
| 74 |
+
Reduce< plus<T>, T > scalar_reduce;
|
| 75 |
+
result.clear();
|
| 76 |
+
|
| 77 |
+
CUTLASS_PRAGMA_UNROLL
|
| 78 |
+
for (auto i = 0; i < N; ++i) {
|
| 79 |
+
result[0] = scalar_reduce(result[0], in[i]);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
return result;
|
| 83 |
+
}
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 87 |
+
|
| 88 |
+
/// Partial specializations of Reduce for Array<half_t, N>
|
| 89 |
+
template <int N>
|
| 90 |
+
struct Reduce < plus<half_t>, Array<half_t, N> > {
|
| 91 |
+
|
| 92 |
+
CUTLASS_HOST_DEVICE
|
| 93 |
+
Array<half_t, 1> operator()(Array<half_t, N> const &input) {
|
| 94 |
+
|
| 95 |
+
Array<half_t, 1> result;
|
| 96 |
+
|
| 97 |
+
// If there is only 1 element - there is nothing to reduce
|
| 98 |
+
if( N ==1 ){
|
| 99 |
+
|
| 100 |
+
result[0] = input.front();
|
| 101 |
+
|
| 102 |
+
} else {
|
| 103 |
+
|
| 104 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
|
| 105 |
+
|
| 106 |
+
__half result_d;
|
| 107 |
+
Array<half_t, 1> const *in_ptr_half = reinterpret_cast<Array<half_t, 1> const *>(&input);
|
| 108 |
+
Array<half_t, 2> const *in_ptr_half2 = reinterpret_cast<Array<half_t, 2> const *>(&input);
|
| 109 |
+
__half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
|
| 110 |
+
|
| 111 |
+
// Set initial result = first half2, in case N==2
|
| 112 |
+
__half2 tmp_result = x_in_half2[0];
|
| 113 |
+
|
| 114 |
+
CUTLASS_PRAGMA_UNROLL
|
| 115 |
+
for (int i = 1; i < N/2; ++i) {
|
| 116 |
+
|
| 117 |
+
tmp_result = __hadd2(x_in_half2[i], tmp_result);
|
| 118 |
+
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
|
| 122 |
+
|
| 123 |
+
// One final step is needed for odd "N" (to add the (N-1)th element)
|
| 124 |
+
if( N%2 ){
|
| 125 |
+
|
| 126 |
+
__half last_element;
|
| 127 |
+
Array<half_t, 1> tmp_last;
|
| 128 |
+
Array<half_t, 1> *tmp_last_ptr = &tmp_last;
|
| 129 |
+
tmp_last_ptr[0] = in_ptr_half[N-1];
|
| 130 |
+
last_element = reinterpret_cast<__half const &>(tmp_last);
|
| 131 |
+
|
| 132 |
+
result_d = __hadd(result_d, last_element);
|
| 133 |
+
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
Array<half_t, 1> *result_ptr = &result;
|
| 137 |
+
*result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
|
| 138 |
+
|
| 139 |
+
#else
|
| 140 |
+
|
| 141 |
+
Reduce< plus<half_t>, half_t > scalar_reduce;
|
| 142 |
+
result.clear();
|
| 143 |
+
|
| 144 |
+
CUTLASS_PRAGMA_UNROLL
|
| 145 |
+
for (auto i = 0; i < N; ++i) {
|
| 146 |
+
|
| 147 |
+
result[0] = scalar_reduce(result[0], input[i]);
|
| 148 |
+
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
#endif
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
return result;
|
| 155 |
+
|
| 156 |
+
}
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 161 |
+
|
| 162 |
+
/// Partial specializations of Reduce for AlignedArray<half_t, N>
|
| 163 |
+
template <int N>
|
| 164 |
+
struct Reduce < plus<half_t>, AlignedArray<half_t, N> > {
|
| 165 |
+
|
| 166 |
+
CUTLASS_HOST_DEVICE
|
| 167 |
+
Array<half_t, 1> operator()(AlignedArray<half_t, N> const &input) {
|
| 168 |
+
|
| 169 |
+
Array<half_t, 1> result;
|
| 170 |
+
|
| 171 |
+
// If there is only 1 element - there is nothing to reduce
|
| 172 |
+
if( N ==1 ){
|
| 173 |
+
|
| 174 |
+
result[0] = input.front();
|
| 175 |
+
|
| 176 |
+
} else {
|
| 177 |
+
|
| 178 |
+
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)
|
| 179 |
+
|
| 180 |
+
__half result_d;
|
| 181 |
+
AlignedArray<half_t, 1> const *in_ptr_half = reinterpret_cast<AlignedArray<half_t, 1> const *>(&input);
|
| 182 |
+
AlignedArray<half_t, 2> const *in_ptr_half2 = reinterpret_cast<AlignedArray<half_t, 2> const *>(&input);
|
| 183 |
+
__half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);
|
| 184 |
+
|
| 185 |
+
// Set initial result = first half2, in case N==2
|
| 186 |
+
__half2 tmp_result = x_in_half2[0];
|
| 187 |
+
|
| 188 |
+
CUTLASS_PRAGMA_UNROLL
|
| 189 |
+
for (int i = 1; i < N/2; ++i) {
|
| 190 |
+
|
| 191 |
+
tmp_result = __hadd2(x_in_half2[i], tmp_result);
|
| 192 |
+
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
|
| 196 |
+
|
| 197 |
+
// One final step is needed for odd "N" (to add the (N-1)th element)
|
| 198 |
+
if( N%2 ){
|
| 199 |
+
|
| 200 |
+
__half last_element;
|
| 201 |
+
AlignedArray<half_t, 1> tmp_last;
|
| 202 |
+
AlignedArray<half_t, 1> *tmp_last_ptr = &tmp_last;
|
| 203 |
+
tmp_last_ptr[0] = in_ptr_half[N-1];
|
| 204 |
+
last_element = reinterpret_cast<__half const &>(tmp_last);
|
| 205 |
+
|
| 206 |
+
result_d = __hadd(result_d, last_element);
|
| 207 |
+
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
Array<half_t, 1> *result_ptr = &result;
|
| 211 |
+
*result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);
|
| 212 |
+
|
| 213 |
+
#else
|
| 214 |
+
|
| 215 |
+
Reduce< plus<half_t>, half_t > scalar_reduce;
|
| 216 |
+
result.clear();
|
| 217 |
+
|
| 218 |
+
CUTLASS_PRAGMA_UNROLL
|
| 219 |
+
for (auto i = 0; i < N; ++i) {
|
| 220 |
+
|
| 221 |
+
result[0] = scalar_reduce(result[0], input[i]);
|
| 222 |
+
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
#endif
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
return result;
|
| 229 |
+
|
| 230 |
+
}
|
| 231 |
+
};
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
}
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Kernel performing a reduction over densely packed tensors in global memory
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/tensor_ref.h"
|
| 39 |
+
#include "cutlass/numeric_types.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/functional.h"
|
| 42 |
+
#include "cutlass/numeric_conversion.h"
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace cutlass {
|
| 47 |
+
namespace reduction {
|
| 48 |
+
namespace thread {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Mixed-precision reduction
|
| 53 |
+
template <
|
| 54 |
+
typename ElementAccumulator_,
|
| 55 |
+
typename Element_,
|
| 56 |
+
int Count = 1
|
| 57 |
+
>
|
| 58 |
+
struct ReduceAdd {
|
| 59 |
+
|
| 60 |
+
//
|
| 61 |
+
// Type definitions
|
| 62 |
+
//
|
| 63 |
+
|
| 64 |
+
using ElementAccumulator = ElementAccumulator_;
|
| 65 |
+
using Element = Element_;
|
| 66 |
+
static int const kCount = Count;
|
| 67 |
+
|
| 68 |
+
using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;
|
| 69 |
+
using FragmentElement = cutlass::Array<Element, kCount>;
|
| 70 |
+
|
| 71 |
+
struct Params { };
|
| 72 |
+
|
| 73 |
+
//
|
| 74 |
+
// Data members
|
| 75 |
+
//
|
| 76 |
+
|
| 77 |
+
/// Parameters object
|
| 78 |
+
Params params;
|
| 79 |
+
|
| 80 |
+
//
|
| 81 |
+
// Methods
|
| 82 |
+
//
|
| 83 |
+
|
| 84 |
+
/// Constructor
|
| 85 |
+
CUTLASS_HOST_DEVICE
|
| 86 |
+
ReduceAdd(Params params_ = Params()): params(params_) { }
|
| 87 |
+
|
| 88 |
+
/// Operator
|
| 89 |
+
CUTLASS_HOST_DEVICE
|
| 90 |
+
FragmentAccumulator operator()(
|
| 91 |
+
FragmentAccumulator accumulator,
|
| 92 |
+
FragmentElement element) const {
|
| 93 |
+
|
| 94 |
+
plus<FragmentAccumulator> op;
|
| 95 |
+
|
| 96 |
+
NumericArrayConverter<
|
| 97 |
+
ElementAccumulator,
|
| 98 |
+
Element,
|
| 99 |
+
kCount,
|
| 100 |
+
PreferredRoundingMode<ElementAccumulator, Element>::kRound> converter;
|
| 101 |
+
|
| 102 |
+
return op(accumulator, converter(element));
|
| 103 |
+
}
|
| 104 |
+
};
|
| 105 |
+
|
| 106 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 107 |
+
|
| 108 |
+
namespace detail {
|
| 109 |
+
|
| 110 |
+
/// Special handling for binary operators
|
| 111 |
+
template <typename ReductionOp, typename Element, int N>
|
| 112 |
+
struct VectorizeArrayOperation {
|
| 113 |
+
|
| 114 |
+
using ValueType = Array<Element, N>;
|
| 115 |
+
|
| 116 |
+
CUTLASS_HOST_DEVICE
|
| 117 |
+
ValueType operator()(
|
| 118 |
+
ReductionOp const &reduction_op,
|
| 119 |
+
ValueType const &lhs,
|
| 120 |
+
ValueType const &rhs) const {
|
| 121 |
+
|
| 122 |
+
ValueType result;
|
| 123 |
+
|
| 124 |
+
CUTLASS_PRAGMA_UNROLL
|
| 125 |
+
for (int i = 0; i < N; ++i) {
|
| 126 |
+
result[i] = reduction_op(lhs[i], rhs[i]);
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
return result;
|
| 130 |
+
}
|
| 131 |
+
};
|
| 132 |
+
|
| 133 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 134 |
+
|
| 135 |
+
template <typename ReductionOp, typename Element, int N>
|
| 136 |
+
struct ReduceArrayOperation {
|
| 137 |
+
|
| 138 |
+
using ArrayType = Array<Element, N>;
|
| 139 |
+
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
Element operator()(
|
| 142 |
+
ReductionOp const &reduction_op,
|
| 143 |
+
ArrayType const &array) const {
|
| 144 |
+
|
| 145 |
+
Element item = reduction_op(array[0], array[1]);
|
| 146 |
+
|
| 147 |
+
CUTLASS_PRAGMA_UNROLL
|
| 148 |
+
for (int i = 2; i < N; ++i) {
|
| 149 |
+
item = reduction_op(item, array[i]);
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
return item;
|
| 153 |
+
}
|
| 154 |
+
};
|
| 155 |
+
|
| 156 |
+
template <int N>
|
| 157 |
+
struct ReduceArrayOperation<logical_and<uint1b_t>, uint1b_t, N> {
|
| 158 |
+
|
| 159 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 160 |
+
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
uint1b_t operator()(
|
| 163 |
+
logical_and<uint1b_t> const &reduction_op,
|
| 164 |
+
ArrayType const &array) const {
|
| 165 |
+
|
| 166 |
+
uint8_t const *ptr = reinterpret_cast<uint8_t const *>(&array);
|
| 167 |
+
bool item = false;
|
| 168 |
+
|
| 169 |
+
CUTLASS_PRAGMA_UNROLL
|
| 170 |
+
for (int byte = 0; byte < (N + 7) / 8; ++byte) {
|
| 171 |
+
uint8_t bits = ptr[byte];
|
| 172 |
+
item = (item || !bits);
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
return uint1b_t{!item};
|
| 176 |
+
}
|
| 177 |
+
};
|
| 178 |
+
|
| 179 |
+
template <int N>
|
| 180 |
+
struct ReduceArrayOperation<logical_or<uint1b_t>, uint1b_t, N> {
|
| 181 |
+
|
| 182 |
+
using ArrayType = Array<uint1b_t, N>;
|
| 183 |
+
|
| 184 |
+
CUTLASS_HOST_DEVICE
|
| 185 |
+
uint1b_t operator()(
|
| 186 |
+
logical_and<uint1b_t> const &reduction_op,
|
| 187 |
+
ArrayType const &array) const {
|
| 188 |
+
|
| 189 |
+
uint8_t const *ptr = reinterpret_cast<uint8_t const *>(&array);
|
| 190 |
+
bool item = true;
|
| 191 |
+
|
| 192 |
+
CUTLASS_PRAGMA_UNROLL
|
| 193 |
+
for (int byte = 0; byte < (N + 7) / 8; ++byte) {
|
| 194 |
+
uint8_t bits = ptr[byte];
|
| 195 |
+
item = (item || bits);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
return uint1b_t{item};
|
| 199 |
+
}
|
| 200 |
+
};
|
| 201 |
+
|
| 202 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 203 |
+
|
| 204 |
+
/// Helper function to infer template argument types
|
| 205 |
+
template <typename ReductionOp, typename Element, int N>
|
| 206 |
+
CUTLASS_HOST_DEVICE
|
| 207 |
+
Array<Element, N> ApplyArrayOperator(
|
| 208 |
+
ReductionOp const &reduction_op,
|
| 209 |
+
Array<Element, N> const &lhs,
|
| 210 |
+
Array<Element, N> const &rhs) {
|
| 211 |
+
|
| 212 |
+
VectorizeArrayOperation<ReductionOp, Element, N> vectorize_op;
|
| 213 |
+
|
| 214 |
+
return vectorize_op(reduction_op, lhs, rhs);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Helper to reduce an array
|
| 218 |
+
template <typename ReductionOp, typename Element, int N>
|
| 219 |
+
Element ReduceArray(ReductionOp const &reduction_op, Array<Element, N> const &array) {
|
| 220 |
+
ReduceArrayOperation<ReductionOp, Element, N> reduce_array_op;
|
| 221 |
+
|
| 222 |
+
return reduce_array_op(reduction_op, array);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 226 |
+
|
| 227 |
+
} // namespace detail
|
| 228 |
+
|
| 229 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 230 |
+
|
| 231 |
+
} // namespace thread
|
| 232 |
+
} // namespace reduction
|
| 233 |
+
} // namespace cutlass
|
| 234 |
+
|
| 235 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
#include "cutlass/coord.h"
|
| 36 |
+
|
| 37 |
+
namespace cutlass {
|
| 38 |
+
namespace reduction {
|
| 39 |
+
struct DefaultBlockSwizzle {
|
| 40 |
+
/// Ctor
|
| 41 |
+
CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {}
|
| 42 |
+
|
| 43 |
+
/// Swizzle the block index.
|
| 44 |
+
CUTLASS_DEVICE dim3 swizzle() { return blockIdx; }
|
| 45 |
+
|
| 46 |
+
///
|
| 47 |
+
CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size,
|
| 48 |
+
Coord<3> const &OutputTile) {
|
| 49 |
+
assert(OutputTile[0] == 1 && OutputTile[1] == 1);
|
| 50 |
+
assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0);
|
| 51 |
+
dim3 grid;
|
| 52 |
+
grid.x = problem_size[0] * problem_size[1] * problem_size[2]
|
| 53 |
+
/ OutputTile[2] ;
|
| 54 |
+
return grid;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
///
|
| 58 |
+
CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) {
|
| 59 |
+
assert(SubTile[0] == 1 && SubTile[1] == 1);
|
| 60 |
+
dim3 block = swizzle();
|
| 61 |
+
Coord<3> threadblock_offset =
|
| 62 |
+
make_Coord(0, 0, block.x * SubTile[2]);
|
| 63 |
+
return threadblock_offset;
|
| 64 |
+
}
|
| 65 |
+
};
|
| 66 |
+
} // namespace reduction
|
| 67 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/* \file
|
| 32 |
+
\brief Performs comparison between two elements with support for floating-point comparisons.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "numeric_types.h"
|
| 38 |
+
#include "complex.h"
|
| 39 |
+
|
| 40 |
+
namespace cutlass {
|
| 41 |
+
|
| 42 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 43 |
+
|
| 44 |
+
template <typename T, typename U = T>
|
| 45 |
+
CUTLASS_HOST_DEVICE
|
| 46 |
+
bool relatively_equal(T a, T b, U epsilon, U nonzero_floor);
|
| 47 |
+
|
| 48 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace detail {
|
| 51 |
+
|
| 52 |
+
// This floating-point comparison function implements the method described in
|
| 53 |
+
//
|
| 54 |
+
// https://floating-point-gui.de/errors/comparison/
|
| 55 |
+
//
|
| 56 |
+
template <typename T>
|
| 57 |
+
CUTLASS_HOST_DEVICE
|
| 58 |
+
bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) {
|
| 59 |
+
|
| 60 |
+
#if defined(__CUDACC_RTC__)
|
| 61 |
+
using cuda::std::abs;
|
| 62 |
+
#else
|
| 63 |
+
using std::abs;
|
| 64 |
+
#endif
|
| 65 |
+
|
| 66 |
+
T abs_A = abs(a);
|
| 67 |
+
T abs_B = abs(b);
|
| 68 |
+
T diff = abs(a - b);
|
| 69 |
+
T zero = T(0);
|
| 70 |
+
|
| 71 |
+
if (a == b) {
|
| 72 |
+
return true;
|
| 73 |
+
}
|
| 74 |
+
else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) {
|
| 75 |
+
return diff < epsilon * nonzero_floor;
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
return diff < epsilon * (abs_A + abs_B);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
} // namespace detail
|
| 82 |
+
|
| 83 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 84 |
+
|
| 85 |
+
template <>
|
| 86 |
+
CUTLASS_HOST_DEVICE
|
| 87 |
+
bool relatively_equal<bool>(bool a, bool b, bool, bool) {
|
| 88 |
+
return (a == b);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
template <>
|
| 92 |
+
CUTLASS_HOST_DEVICE
|
| 93 |
+
bool relatively_equal<uint1b_t>(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) {
|
| 94 |
+
return (a == b);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
template <>
|
| 98 |
+
CUTLASS_HOST_DEVICE
|
| 99 |
+
bool relatively_equal<int2b_t>(int2b_t a, int2b_t b, int2b_t, int2b_t) {
|
| 100 |
+
return (a == b);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
template <>
|
| 104 |
+
CUTLASS_HOST_DEVICE
|
| 105 |
+
bool relatively_equal<uint2b_t>(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) {
|
| 106 |
+
return (a == b);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
template <>
|
| 110 |
+
CUTLASS_HOST_DEVICE
|
| 111 |
+
bool relatively_equal<int4b_t>(int4b_t a, int4b_t b, int4b_t, int4b_t) {
|
| 112 |
+
return (a == b);
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
template <>
|
| 116 |
+
CUTLASS_HOST_DEVICE
|
| 117 |
+
bool relatively_equal<uint4b_t>(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) {
|
| 118 |
+
return (a == b);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
template <>
|
| 122 |
+
CUTLASS_HOST_DEVICE
|
| 123 |
+
bool relatively_equal<int8_t>(int8_t a, int8_t b, int8_t, int8_t) {
|
| 124 |
+
return (a == b);
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <>
|
| 128 |
+
CUTLASS_HOST_DEVICE
|
| 129 |
+
bool relatively_equal<uint8_t>(uint8_t a, uint8_t b, uint8_t, uint8_t) {
|
| 130 |
+
return (a == b);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
template <>
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
bool relatively_equal<int16_t>(int16_t a, int16_t b, int16_t, int16_t) {
|
| 136 |
+
return (a == b);
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
template <>
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
bool relatively_equal<uint16_t>(uint16_t a, uint16_t b, uint16_t, uint16_t) {
|
| 142 |
+
return (a == b);
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
template <>
|
| 146 |
+
CUTLASS_HOST_DEVICE
|
| 147 |
+
bool relatively_equal<int32_t>(int32_t a, int32_t b, int32_t, int32_t) {
|
| 148 |
+
return (a == b);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
template <>
|
| 152 |
+
CUTLASS_HOST_DEVICE
|
| 153 |
+
bool relatively_equal<uint32_t>(uint32_t a, uint32_t b, uint32_t, uint32_t) {
|
| 154 |
+
return (a == b);
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
template <>
|
| 158 |
+
CUTLASS_HOST_DEVICE
|
| 159 |
+
bool relatively_equal<int64_t>(int64_t a, int64_t b, int64_t, int64_t) {
|
| 160 |
+
return (a == b);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
template <>
|
| 164 |
+
CUTLASS_HOST_DEVICE
|
| 165 |
+
bool relatively_equal<uint64_t>(uint64_t a, uint64_t b, uint64_t, uint64_t) {
|
| 166 |
+
return (a == b);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 170 |
+
|
| 171 |
+
template <>
|
| 172 |
+
CUTLASS_HOST_DEVICE
|
| 173 |
+
bool relatively_equal<float_e4m3_t>(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) {
|
| 174 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
template <>
|
| 178 |
+
CUTLASS_HOST_DEVICE
|
| 179 |
+
bool relatively_equal<float_e5m2_t>(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) {
|
| 180 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
template <>
|
| 184 |
+
CUTLASS_HOST_DEVICE
|
| 185 |
+
bool relatively_equal<half_t>(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) {
|
| 186 |
+
return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
template <>
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
bool relatively_equal<bfloat16_t>(
|
| 192 |
+
bfloat16_t a,
|
| 193 |
+
bfloat16_t b,
|
| 194 |
+
bfloat16_t epsilon,
|
| 195 |
+
bfloat16_t nonzero_floor) {
|
| 196 |
+
|
| 197 |
+
return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
template <>
|
| 201 |
+
CUTLASS_HOST_DEVICE
|
| 202 |
+
bool relatively_equal<tfloat32_t>(
|
| 203 |
+
tfloat32_t a,
|
| 204 |
+
tfloat32_t b,
|
| 205 |
+
tfloat32_t epsilon,
|
| 206 |
+
tfloat32_t nonzero_floor) {
|
| 207 |
+
|
| 208 |
+
return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
template <>
|
| 212 |
+
CUTLASS_HOST_DEVICE
|
| 213 |
+
bool relatively_equal<float>(float a, float b, float epsilon, float nonzero_floor) {
|
| 214 |
+
return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
template <>
|
| 219 |
+
CUTLASS_HOST_DEVICE
|
| 220 |
+
bool relatively_equal<double>(double a, double b, double epsilon, double nonzero_floor) {
|
| 221 |
+
return detail::relatively_equal_float(a, b, epsilon, nonzero_floor);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
template<typename T>
|
| 225 |
+
CUTLASS_HOST_DEVICE
|
| 226 |
+
bool relatively_equal(complex<T> a, complex<T> b, T epsilon, T nonzero_floor) {
|
| 227 |
+
#if defined(__CUDACC_RTC__)
|
| 228 |
+
using cuda::std::abs;
|
| 229 |
+
#else
|
| 230 |
+
using std::abs;
|
| 231 |
+
#endif
|
| 232 |
+
|
| 233 |
+
T abs_A = abs(a);
|
| 234 |
+
T abs_B = abs(b);
|
| 235 |
+
T diff = abs(a - b);
|
| 236 |
+
complex<T> zero = complex<T>{T{}, T{}};
|
| 237 |
+
|
| 238 |
+
if (a == b) {
|
| 239 |
+
return true;
|
| 240 |
+
}
|
| 241 |
+
else if (a == zero || b == zero || diff < nonzero_floor) {
|
| 242 |
+
return diff < epsilon * nonzero_floor;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
return diff < epsilon * (abs_A + abs_B);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
template <typename T>
|
| 249 |
+
CUTLASS_HOST_DEVICE
|
| 250 |
+
bool relatively_equal(complex<T> a, complex<T> b, complex<T> epsilon, complex<T> nonzero_floor) {
|
| 251 |
+
#if defined(__CUDACC_RTC__)
|
| 252 |
+
using cuda::std::abs;
|
| 253 |
+
#else
|
| 254 |
+
using std::abs;
|
| 255 |
+
#endif
|
| 256 |
+
|
| 257 |
+
T abs_A = abs(a);
|
| 258 |
+
T abs_B = abs(b);
|
| 259 |
+
complex<T> diff = a - b;
|
| 260 |
+
T abs_diff = abs(diff);
|
| 261 |
+
complex<T> zero = complex<T>{T{}, T{}};
|
| 262 |
+
|
| 263 |
+
if (a == b) {
|
| 264 |
+
return true;
|
| 265 |
+
}
|
| 266 |
+
else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) {
|
| 267 |
+
return abs_diff < abs(epsilon * nonzero_floor);
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
return abs_diff < abs(epsilon) * (abs_A + abs_B);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
template <>
|
| 275 |
+
CUTLASS_HOST_DEVICE
|
| 276 |
+
bool relatively_equal<float_e2m3_t>(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) {
|
| 277 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
template <>
|
| 281 |
+
CUTLASS_HOST_DEVICE
|
| 282 |
+
bool relatively_equal<float_e3m2_t>(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) {
|
| 283 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
template <>
|
| 287 |
+
CUTLASS_HOST_DEVICE
|
| 288 |
+
bool relatively_equal<float_e2m1_t>(float_e2m1_t a, float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) {
|
| 289 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 290 |
+
}
|
| 291 |
+
template <>
|
| 292 |
+
CUTLASS_HOST_DEVICE
|
| 293 |
+
bool relatively_equal<float_ue8m0_t>(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) {
|
| 294 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
template <>
|
| 298 |
+
CUTLASS_HOST_DEVICE
|
| 299 |
+
bool relatively_equal<float_ue4m3_t>(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) {
|
| 300 |
+
return detail::relatively_equal_float<float>(a, b, epsilon, nonzero_floor);
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 304 |
+
|
| 305 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/semaphore.h
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
|
| 44 |
+
#include "cutlass/gemm/gemm.h"
|
| 45 |
+
|
| 46 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// CTA-wide semaphore for inter-CTA synchronization.
|
| 53 |
+
class Semaphore {
|
| 54 |
+
public:
|
| 55 |
+
|
| 56 |
+
int *lock;
|
| 57 |
+
bool wait_thread;
|
| 58 |
+
int state;
|
| 59 |
+
|
| 60 |
+
public:
|
| 61 |
+
|
| 62 |
+
/// Implements a semaphore to wait for a flag to reach a given value
|
| 63 |
+
CUTLASS_HOST_DEVICE
|
| 64 |
+
Semaphore(int *lock_, int thread_id):
|
| 65 |
+
lock(lock_),
|
| 66 |
+
wait_thread(thread_id < 0 || thread_id == 0),
|
| 67 |
+
state(-1) {
|
| 68 |
+
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
/// Permit fetching the synchronization mechanism early
|
| 72 |
+
CUTLASS_DEVICE
|
| 73 |
+
void fetch() {
|
| 74 |
+
if (wait_thread) {
|
| 75 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
| 76 |
+
asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
|
| 77 |
+
#else
|
| 78 |
+
asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
|
| 79 |
+
#endif
|
| 80 |
+
}
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Gets the internal state
|
| 84 |
+
CUTLASS_DEVICE
|
| 85 |
+
int get_state() const {
|
| 86 |
+
return state;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
/// Waits until the semaphore is equal to the given value
|
| 90 |
+
CUTLASS_DEVICE
|
| 91 |
+
void wait(int status = 0) {
|
| 92 |
+
while( __syncthreads_and(state != status) ) {
|
| 93 |
+
fetch();
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
__syncthreads();
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
/// Updates the lock with the given result
|
| 100 |
+
CUTLASS_DEVICE
|
| 101 |
+
void release(int status = 0) {
|
| 102 |
+
__syncthreads();
|
| 103 |
+
|
| 104 |
+
if (wait_thread) {
|
| 105 |
+
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
| 106 |
+
asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
|
| 107 |
+
#else
|
| 108 |
+
asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
|
| 109 |
+
#endif
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
};
|
| 113 |
+
|
| 114 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 115 |
+
|
| 116 |
+
} // namespace cutlass
|
| 117 |
+
|
| 118 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h
ADDED
|
@@ -0,0 +1,1388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Provides a mechanism for packing and unpacking elements smaller than one byte
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/integer_subbyte.h"
|
| 38 |
+
#include "cutlass/fast_math.h"
|
| 39 |
+
|
| 40 |
+
namespace cutlass {
|
| 41 |
+
|
| 42 |
+
namespace detail {
|
| 43 |
+
// This is an implementation detail of cutlass::SubbyteReference and.
|
| 44 |
+
// cutlass::HostTensor. For a given logical element type Element,
|
| 45 |
+
// and its corresponding storage (physical) element type StorageUnit,
|
| 46 |
+
// it computes quantities that help with managing allocations.
|
| 47 |
+
//
|
| 48 |
+
// CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support
|
| 49 |
+
// packed arrays of subbyte types such as int4. Element is the "logical" type
|
| 50 |
+
// for computations, while CUTLASS uses StorageUnit as the element type
|
| 51 |
+
// of a packed array of Element. If Element is not a subbyte type,
|
| 52 |
+
// then the corresponding StorageUnit type is just Element itself.
|
| 53 |
+
//
|
| 54 |
+
// The ContainerType is always calculated as an array StorageUnit type (the StorageUnit
|
| 55 |
+
// is always a byte for subbyte types),
|
| 56 |
+
// and its number of bits is the lcm of the subbyte type's number of bits and 8.
|
| 57 |
+
// Below are some examples for different subbyte types.
|
| 58 |
+
//
|
| 59 |
+
// * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t)
|
| 60 |
+
// * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t)
|
| 61 |
+
template<class Element, class StorageUnit>
|
| 62 |
+
struct StorageContainerCalculator {
|
| 63 |
+
// kContainerTypeNumBits: The number of bits needed for ContainerType
|
| 64 |
+
static constexpr int kContainerTypeNumBits = (sizeof_bits<Element>::value < 8) ? cutlass::lcm_cxx11(sizeof_bits<Element>::value, sizeof_bits<StorageUnit>::value) : sizeof_bits<Element>::value;
|
| 65 |
+
static_assert(kContainerTypeNumBits % sizeof_bits<Element>::value == 0, "The bits of ContainerType should be divisible by the element's number of bits");
|
| 66 |
+
// kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance
|
| 67 |
+
static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits<Element>::value;
|
| 68 |
+
/// 3. kContainerTypeNumBytes: The number of bytes per ContainerType instance
|
| 69 |
+
static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8;
|
| 70 |
+
/// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType
|
| 71 |
+
static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits<StorageUnit>::value;
|
| 72 |
+
|
| 73 |
+
static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero");
|
| 74 |
+
static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero");
|
| 75 |
+
static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero");
|
| 76 |
+
};
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 80 |
+
|
| 81 |
+
/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It
|
| 82 |
+
/// assumes these sub-byte elements are packed in a traditional C++ numeric type.
|
| 83 |
+
///
|
| 84 |
+
/// The intended application is to provide a mechanism to indirectly reference elements in
|
| 85 |
+
/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller
|
| 86 |
+
/// than one byte.
|
| 87 |
+
///
|
| 88 |
+
/// Supports basic pointer arithmetic:
|
| 89 |
+
///
|
| 90 |
+
/// Example:
|
| 91 |
+
///
|
| 92 |
+
/// int4b_t *ptr = ...;
|
| 93 |
+
///
|
| 94 |
+
/// SubbyteReference<int4b_t> ref = ptr;
|
| 95 |
+
/// ref += 15;
|
| 96 |
+
///
|
| 97 |
+
/// int4b_t x = ref; // load an int4b_t
|
| 98 |
+
/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store
|
| 99 |
+
///
|
| 100 |
+
template <
|
| 101 |
+
typename Element_, /// CUTLASS numeric element type.
|
| 102 |
+
typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer
|
| 103 |
+
/// number of objects of type Element.
|
| 104 |
+
class = void
|
| 105 |
+
>
|
| 106 |
+
class ConstSubbyteReference {
|
| 107 |
+
public:
|
| 108 |
+
|
| 109 |
+
using Element = Element_;
|
| 110 |
+
using Storage = Storage_;
|
| 111 |
+
using StoragePointer = Storage const *;
|
| 112 |
+
|
| 113 |
+
static_assert(sizeof_bits<Element>::value <= sizeof_bits<Storage>::value,
|
| 114 |
+
"Size of Element must not be greater than Storage.");
|
| 115 |
+
|
| 116 |
+
static_assert(!(sizeof_bits<Storage>::value % sizeof_bits<Element>::value),
|
| 117 |
+
"Storage must be divisible by Element");
|
| 118 |
+
|
| 119 |
+
private:
|
| 120 |
+
|
| 121 |
+
///! Number of elements per storage vector
|
| 122 |
+
int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
|
| 123 |
+
|
| 124 |
+
///! Bit mask
|
| 125 |
+
Storage const kMask =
|
| 126 |
+
((sizeof_bits<Element>::value < sizeof_bits<Storage>::value) ?
|
| 127 |
+
(Storage(1) << sizeof_bits<Element>::value) - Storage(1) :
|
| 128 |
+
~Storage(0));
|
| 129 |
+
|
| 130 |
+
private:
|
| 131 |
+
|
| 132 |
+
/// Pointer to array containing element
|
| 133 |
+
StoragePointer ptr_;
|
| 134 |
+
|
| 135 |
+
/// Offset (in units of elements) from pointer.
|
| 136 |
+
///
|
| 137 |
+
/// Invariant: must always be in range [0, kElementsPerVector)
|
| 138 |
+
int offset_;
|
| 139 |
+
|
| 140 |
+
public:
|
| 141 |
+
|
| 142 |
+
CUTLASS_HOST_DEVICE
|
| 143 |
+
ConstSubbyteReference(): ptr_(nullptr), offset_(0) { }
|
| 144 |
+
|
| 145 |
+
/// Constructor
|
| 146 |
+
CUTLASS_HOST_DEVICE
|
| 147 |
+
ConstSubbyteReference(
|
| 148 |
+
Element const *ptr, /// pointer to memory
|
| 149 |
+
int64_t offset /// logical offset in units of Element
|
| 150 |
+
):
|
| 151 |
+
ptr_(reinterpret_cast<StoragePointer>(ptr)),
|
| 152 |
+
offset_(0) {
|
| 153 |
+
|
| 154 |
+
int64_t offset_in_vectors = offset / kElementsPerVector;
|
| 155 |
+
int64_t offset_in_elements = offset % kElementsPerVector;
|
| 156 |
+
|
| 157 |
+
ptr_ += offset_in_vectors;
|
| 158 |
+
offset_ = int(offset_in_elements);
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
/// Constructor
|
| 162 |
+
CUTLASS_HOST_DEVICE
|
| 163 |
+
ConstSubbyteReference(
|
| 164 |
+
Element *ptr = nullptr
|
| 165 |
+
): ConstSubbyteReference(ptr, 0) { }
|
| 166 |
+
|
| 167 |
+
/// Gets storage pointer
|
| 168 |
+
CUTLASS_HOST_DEVICE
|
| 169 |
+
StoragePointer storage_pointer() const {
|
| 170 |
+
return ptr_;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Gets element offset within storage vector
|
| 174 |
+
CUTLASS_HOST_DEVICE
|
| 175 |
+
int element_offset() const {
|
| 176 |
+
return offset_;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
/// Unpacks an element from memory
|
| 180 |
+
CUTLASS_HOST_DEVICE
|
| 181 |
+
Element get() const {
|
| 182 |
+
Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits<Element>::value)) & kMask);
|
| 183 |
+
return reinterpret_cast<Element const &>(item);
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
/// Unpacks an element from memory
|
| 187 |
+
CUTLASS_HOST_DEVICE
|
| 188 |
+
operator Element() const {
|
| 189 |
+
return get();
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
/// Adds an offset in units of elements to the reference
|
| 193 |
+
CUTLASS_HOST_DEVICE
|
| 194 |
+
ConstSubbyteReference &operator+=(int offset) {
|
| 195 |
+
|
| 196 |
+
offset += offset_;
|
| 197 |
+
|
| 198 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 199 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 200 |
+
|
| 201 |
+
ptr_ += offset_in_vectors;
|
| 202 |
+
offset_ = offset_in_elements;
|
| 203 |
+
|
| 204 |
+
return *this;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/// Adds an offset in units of elements to the reference
|
| 208 |
+
CUTLASS_HOST_DEVICE
|
| 209 |
+
ConstSubbyteReference &operator+=(long long offset) {
|
| 210 |
+
|
| 211 |
+
offset += offset_;
|
| 212 |
+
|
| 213 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 214 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 215 |
+
|
| 216 |
+
ptr_ += offset_in_vectors;
|
| 217 |
+
offset_ = offset_in_elements;
|
| 218 |
+
|
| 219 |
+
return *this;
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
/// Adds an offset in units of elements to the reference
|
| 223 |
+
CUTLASS_HOST_DEVICE
|
| 224 |
+
ConstSubbyteReference &operator-=(int offset) {
|
| 225 |
+
|
| 226 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 227 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 228 |
+
|
| 229 |
+
ptr_ -= offset_in_vectors;
|
| 230 |
+
offset_ -= offset_in_elements;
|
| 231 |
+
|
| 232 |
+
if (offset_ < 0) {
|
| 233 |
+
offset_ += kElementsPerVector;
|
| 234 |
+
--ptr_;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
return *this;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Adds an offset in units of elements to the reference
|
| 241 |
+
CUTLASS_HOST_DEVICE
|
| 242 |
+
ConstSubbyteReference &operator-=(long long offset) {
|
| 243 |
+
|
| 244 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 245 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 246 |
+
|
| 247 |
+
ptr_ -= offset_in_vectors;
|
| 248 |
+
offset_ -= offset_in_elements;
|
| 249 |
+
|
| 250 |
+
if (offset_ < 0) {
|
| 251 |
+
offset_ += kElementsPerVector;
|
| 252 |
+
--ptr_;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
return *this;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 259 |
+
CUTLASS_HOST_DEVICE
|
| 260 |
+
ConstSubbyteReference operator+(int offset) const {
|
| 261 |
+
|
| 262 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 263 |
+
ref += offset;
|
| 264 |
+
|
| 265 |
+
return ref;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 269 |
+
CUTLASS_HOST_DEVICE
|
| 270 |
+
ConstSubbyteReference operator+(long long offset) const {
|
| 271 |
+
|
| 272 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 273 |
+
ref += offset;
|
| 274 |
+
|
| 275 |
+
return ref;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 279 |
+
CUTLASS_HOST_DEVICE
|
| 280 |
+
ConstSubbyteReference operator-(int offset) const {
|
| 281 |
+
|
| 282 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 283 |
+
ref -= offset;
|
| 284 |
+
|
| 285 |
+
return ref;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 289 |
+
CUTLASS_HOST_DEVICE
|
| 290 |
+
ConstSubbyteReference operator-=(long long offset) const {
|
| 291 |
+
|
| 292 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 293 |
+
ref -= offset;
|
| 294 |
+
|
| 295 |
+
return ref;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
/// Computes the difference in elements between references
|
| 299 |
+
CUTLASS_HOST_DEVICE
|
| 300 |
+
ptrdiff_t operator-(ConstSubbyteReference ref) const {
|
| 301 |
+
return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/// Explicit cast to int
|
| 305 |
+
CUTLASS_HOST_DEVICE
|
| 306 |
+
explicit operator int() const {
|
| 307 |
+
return int(get());
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
/// Explicit cast to signed 64-bit integer
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
explicit operator int64_t() const {
|
| 313 |
+
return int64_t(get());
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
/// Explicit cast to unsigned 64-bit integer
|
| 317 |
+
CUTLASS_HOST_DEVICE
|
| 318 |
+
explicit operator uint64_t() const {
|
| 319 |
+
return uint64_t(get());
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
/// Explicit cast to float
|
| 323 |
+
CUTLASS_HOST_DEVICE
|
| 324 |
+
explicit operator float() const {
|
| 325 |
+
return float(get());
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
/// Explicit cast to double
|
| 329 |
+
CUTLASS_HOST_DEVICE
|
| 330 |
+
explicit operator double() const {
|
| 331 |
+
return double(get());
|
| 332 |
+
}
|
| 333 |
+
};
|
| 334 |
+
|
| 335 |
+
template <
|
| 336 |
+
typename Element_, /// CUTLASS numeric element type.
|
| 337 |
+
typename Storage_ = /// Underlying storage type. Must be able to hold an integer
|
| 338 |
+
/// number of objects of type Element.
|
| 339 |
+
|
| 340 |
+
#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads.
|
| 341 |
+
#if (__CUDA_ARCH__ >= 700) ///
|
| 342 |
+
uint16_t
|
| 343 |
+
#else
|
| 344 |
+
uint32_t
|
| 345 |
+
#endif
|
| 346 |
+
#else
|
| 347 |
+
uint8_t
|
| 348 |
+
#endif
|
| 349 |
+
,
|
| 350 |
+
class = void
|
| 351 |
+
>
|
| 352 |
+
class SubbyteReference {
|
| 353 |
+
public:
|
| 354 |
+
|
| 355 |
+
using Element = Element_;
|
| 356 |
+
using Storage = Storage_;
|
| 357 |
+
using StoragePointer = Storage *;
|
| 358 |
+
|
| 359 |
+
static_assert(sizeof_bits<Element>::value <= sizeof_bits<Storage>::value,
|
| 360 |
+
"Size of Element must not be greater than Storage.");
|
| 361 |
+
|
| 362 |
+
static_assert(!(sizeof_bits<Storage>::value % sizeof_bits<Element>::value),
|
| 363 |
+
"Storage must be divisible by Element");
|
| 364 |
+
|
| 365 |
+
private:
|
| 366 |
+
|
| 367 |
+
///! Number of elements per storage vector
|
| 368 |
+
int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
|
| 369 |
+
|
| 370 |
+
///! Bit mask
|
| 371 |
+
Storage const kMask =
|
| 372 |
+
((sizeof_bits<Element>::value < sizeof_bits<Storage>::value) ?
|
| 373 |
+
(Storage(1) << sizeof_bits<Element>::value) - Storage(1) :
|
| 374 |
+
~Storage(0));
|
| 375 |
+
|
| 376 |
+
private:
|
| 377 |
+
|
| 378 |
+
/// Pointer to array containing element
|
| 379 |
+
StoragePointer ptr_;
|
| 380 |
+
|
| 381 |
+
/// Offset (in units of elements) from pointer.
|
| 382 |
+
///
|
| 383 |
+
/// Invariant: must always be in range [0, kElementsPerVector)
|
| 384 |
+
int offset_;
|
| 385 |
+
|
| 386 |
+
public:
|
| 387 |
+
|
| 388 |
+
CUTLASS_HOST_DEVICE
|
| 389 |
+
SubbyteReference(): ptr_(nullptr), offset_(0) { }
|
| 390 |
+
|
| 391 |
+
/// Constructor
|
| 392 |
+
CUTLASS_HOST_DEVICE
|
| 393 |
+
SubbyteReference(
|
| 394 |
+
Element *ptr, /// pointer to memory
|
| 395 |
+
int64_t offset /// logical offset in units of Element
|
| 396 |
+
):
|
| 397 |
+
ptr_(reinterpret_cast<StoragePointer>(ptr)),
|
| 398 |
+
offset_(0) {
|
| 399 |
+
|
| 400 |
+
int64_t offset_in_vectors = offset / kElementsPerVector;
|
| 401 |
+
int64_t offset_in_elements = offset % kElementsPerVector;
|
| 402 |
+
|
| 403 |
+
ptr_ += offset_in_vectors;
|
| 404 |
+
offset_ = int(offset_in_elements);
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
/// Constructor
|
| 408 |
+
CUTLASS_HOST_DEVICE
|
| 409 |
+
SubbyteReference(
|
| 410 |
+
Element *ptr = nullptr
|
| 411 |
+
): SubbyteReference(ptr, 0) { }
|
| 412 |
+
|
| 413 |
+
/// Gets storage pointer
|
| 414 |
+
CUTLASS_HOST_DEVICE
|
| 415 |
+
StoragePointer storage_pointer() const {
|
| 416 |
+
return ptr_;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
/// Gets storage pointer
|
| 420 |
+
CUTLASS_HOST_DEVICE
|
| 421 |
+
Element * operator&() const {
|
| 422 |
+
return reinterpret_cast<Element *>(ptr_);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
/// Gets element offset within storage vector
|
| 426 |
+
CUTLASS_HOST_DEVICE
|
| 427 |
+
int element_offset() const {
|
| 428 |
+
return offset_;
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
/// Unpacks an element from memory
|
| 432 |
+
CUTLASS_HOST_DEVICE
|
| 433 |
+
Element get() const {
|
| 434 |
+
uint8_t const* byte_ptr = reinterpret_cast<uint8_t const*>(ptr_);
|
| 435 |
+
// Convert offset in elements to offset in bytes
|
| 436 |
+
constexpr int elements_per_byte = cutlass::sizeof_bits<uint8_t>::value / cutlass::sizeof_bits<Element>::value;
|
| 437 |
+
byte_ptr += offset_ / elements_per_byte;
|
| 438 |
+
// Offset of element within a byte
|
| 439 |
+
int byte_offset = offset_ % elements_per_byte;
|
| 440 |
+
uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits<Element>::value)) & kMask);
|
| 441 |
+
return reinterpret_cast<Element const &>(item);
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
/// Stores an element to memory
|
| 445 |
+
CUTLASS_HOST_DEVICE
|
| 446 |
+
SubbyteReference & set(Element const &x) {
|
| 447 |
+
|
| 448 |
+
Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
|
| 449 |
+
Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits<Element>::value)));
|
| 450 |
+
Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits<Element>::value));
|
| 451 |
+
|
| 452 |
+
#if defined(__CUDA_ARCH__)
|
| 453 |
+
|
| 454 |
+
//
|
| 455 |
+
// Homebrew read-modify-write
|
| 456 |
+
//
|
| 457 |
+
Storage original;
|
| 458 |
+
Storage updated;
|
| 459 |
+
|
| 460 |
+
do {
|
| 461 |
+
|
| 462 |
+
original = (*ptr_);
|
| 463 |
+
|
| 464 |
+
updated = Storage((original & kUpdateMask) | new_bits);
|
| 465 |
+
|
| 466 |
+
original = atomicCAS(ptr_, original, updated);
|
| 467 |
+
|
| 468 |
+
} while (updated != original);
|
| 469 |
+
|
| 470 |
+
#else
|
| 471 |
+
|
| 472 |
+
Storage original = (*ptr_);
|
| 473 |
+
Storage updated = Storage((original & kUpdateMask) | new_bits);
|
| 474 |
+
*ptr_ = updated;
|
| 475 |
+
|
| 476 |
+
#endif
|
| 477 |
+
|
| 478 |
+
return *this;
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
////
|
| 482 |
+
|
| 483 |
+
/// Unpacks an element from memory
|
| 484 |
+
CUTLASS_HOST_DEVICE
|
| 485 |
+
operator Element() const {
|
| 486 |
+
return get();
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
/// Stores an element to memory
|
| 490 |
+
CUTLASS_HOST_DEVICE
|
| 491 |
+
SubbyteReference &operator=(Element const & x) {
|
| 492 |
+
return set(x);
|
| 493 |
+
}
|
| 494 |
+
|
| 495 |
+
/// Stores an element to memory
|
| 496 |
+
CUTLASS_HOST_DEVICE
|
| 497 |
+
SubbyteReference &operator=(SubbyteReference const & x) {
|
| 498 |
+
return set(x.get());
|
| 499 |
+
}
|
| 500 |
+
|
| 501 |
+
/// Stores an element to memory
|
| 502 |
+
CUTLASS_HOST_DEVICE
|
| 503 |
+
SubbyteReference &operator=(
|
| 504 |
+
ConstSubbyteReference<Element, Storage> const &x) {
|
| 505 |
+
return set(x.get());
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
/// Adds an offset in units of elements to the reference
|
| 509 |
+
CUTLASS_HOST_DEVICE
|
| 510 |
+
SubbyteReference &operator+=(int offset) {
|
| 511 |
+
|
| 512 |
+
offset += offset_;
|
| 513 |
+
|
| 514 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 515 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 516 |
+
|
| 517 |
+
ptr_ += offset_in_vectors;
|
| 518 |
+
offset_ = offset_in_elements;
|
| 519 |
+
|
| 520 |
+
return *this;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
/// Adds an offset in units of elements to the reference
|
| 524 |
+
CUTLASS_HOST_DEVICE
|
| 525 |
+
SubbyteReference &operator+=(long long offset) {
|
| 526 |
+
|
| 527 |
+
offset += offset_;
|
| 528 |
+
|
| 529 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 530 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 531 |
+
|
| 532 |
+
ptr_ += offset_in_vectors;
|
| 533 |
+
offset_ = offset_in_elements;
|
| 534 |
+
|
| 535 |
+
return *this;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
/// Adds an offset in units of elements to the reference
|
| 539 |
+
CUTLASS_HOST_DEVICE
|
| 540 |
+
SubbyteReference &operator-=(int offset) {
|
| 541 |
+
|
| 542 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 543 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 544 |
+
|
| 545 |
+
ptr_ -= offset_in_vectors;
|
| 546 |
+
offset_ -= offset_in_elements;
|
| 547 |
+
|
| 548 |
+
if (offset_ < 0) {
|
| 549 |
+
offset_ += kElementsPerVector;
|
| 550 |
+
--ptr_;
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
return *this;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/// Adds an offset in units of elements to the reference
|
| 557 |
+
CUTLASS_HOST_DEVICE
|
| 558 |
+
SubbyteReference &operator-=(long long offset) {
|
| 559 |
+
|
| 560 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 561 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 562 |
+
|
| 563 |
+
ptr_ -= offset_in_vectors;
|
| 564 |
+
offset_ -= offset_in_elements;
|
| 565 |
+
|
| 566 |
+
if (offset_ < 0) {
|
| 567 |
+
offset_ += kElementsPerVector;
|
| 568 |
+
--ptr_;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
return *this;
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 575 |
+
CUTLASS_HOST_DEVICE
|
| 576 |
+
SubbyteReference operator+(int offset) const {
|
| 577 |
+
|
| 578 |
+
SubbyteReference ref(ptr_, offset_);
|
| 579 |
+
ref += offset;
|
| 580 |
+
|
| 581 |
+
return ref;
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 585 |
+
CUTLASS_HOST_DEVICE
|
| 586 |
+
SubbyteReference operator+(long long offset) const {
|
| 587 |
+
|
| 588 |
+
SubbyteReference ref(ptr_, offset_);
|
| 589 |
+
ref += offset;
|
| 590 |
+
|
| 591 |
+
return ref;
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 595 |
+
CUTLASS_HOST_DEVICE
|
| 596 |
+
SubbyteReference operator-(int offset) const {
|
| 597 |
+
|
| 598 |
+
SubbyteReference ref(ptr_, offset_);
|
| 599 |
+
ref -= offset;
|
| 600 |
+
|
| 601 |
+
return ref;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 605 |
+
CUTLASS_HOST_DEVICE
|
| 606 |
+
SubbyteReference operator-=(long long offset) const {
|
| 607 |
+
|
| 608 |
+
SubbyteReference ref(ptr_, offset_);
|
| 609 |
+
ref -= offset;
|
| 610 |
+
|
| 611 |
+
return ref;
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
/// Computes the difference in elements between references
|
| 615 |
+
CUTLASS_HOST_DEVICE
|
| 616 |
+
ptrdiff_t operator-(SubbyteReference ref) const {
|
| 617 |
+
return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
/// Explicit cast to int
|
| 621 |
+
CUTLASS_HOST_DEVICE
|
| 622 |
+
explicit operator int() const {
|
| 623 |
+
return int(get());
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
/// Explicit cast to signed 64-bit integer
|
| 627 |
+
CUTLASS_HOST_DEVICE
|
| 628 |
+
explicit operator int64_t() const {
|
| 629 |
+
return int64_t(get());
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
/// Explicit cast to unsigned 64-bit integer
|
| 633 |
+
CUTLASS_HOST_DEVICE
|
| 634 |
+
explicit operator uint64_t() const {
|
| 635 |
+
return uint64_t(get());
|
| 636 |
+
}
|
| 637 |
+
|
| 638 |
+
/// Explicit cast to float
|
| 639 |
+
CUTLASS_HOST_DEVICE
|
| 640 |
+
explicit operator float() const {
|
| 641 |
+
return float(get());
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
/// Explicit cast to double
|
| 645 |
+
CUTLASS_HOST_DEVICE
|
| 646 |
+
explicit operator double() const {
|
| 647 |
+
return double(get());
|
| 648 |
+
}
|
| 649 |
+
};
|
| 650 |
+
|
| 651 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 652 |
+
|
| 653 |
+
template<typename T> using _war = T;
|
| 654 |
+
template <
|
| 655 |
+
typename Element_, /// CUTLASS numeric element type.
|
| 656 |
+
typename Storage_ /// Underlying basic storage type.
|
| 657 |
+
>
|
| 658 |
+
class SubbyteReference<Element_, Storage_,
|
| 659 |
+
typename platform::enable_if<sizeof_bits<Storage_>::value % sizeof_bits<Element_>::value != 0>::type> {
|
| 660 |
+
public:
|
| 661 |
+
|
| 662 |
+
using Element = Element_;
|
| 663 |
+
/// Note: It's possible that StorageUnit is not divisible by Element.
|
| 664 |
+
/// For example, an Element instance might be stored across 2 StorageUnit instances.
|
| 665 |
+
/// Thus, CUTLASS needs a storage vector to hold an integer number of Element instances.
|
| 666 |
+
|
| 667 |
+
using StorageUnit = Storage_;
|
| 668 |
+
private:
|
| 669 |
+
using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator<Element, StorageUnit>;
|
| 670 |
+
public:
|
| 671 |
+
static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits;
|
| 672 |
+
static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit;
|
| 673 |
+
|
| 674 |
+
using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec];
|
| 675 |
+
using StorageVecPointer = StorageVec *;
|
| 676 |
+
|
| 677 |
+
using CudaAtomicType = typename platform::conditional<
|
| 678 |
+
sizeof_bits<StorageUnit>::value == 16,
|
| 679 |
+
uint32_t,
|
| 680 |
+
uint64_t
|
| 681 |
+
>::type;
|
| 682 |
+
|
| 683 |
+
static_assert(sizeof_bits<Element>::value <= sizeof_bits<StorageVec>::value,
|
| 684 |
+
"Size of Element must not be greater than StorageVec.");
|
| 685 |
+
|
| 686 |
+
static_assert(!(sizeof_bits<StorageVec>::value % sizeof_bits<Element>::value),
|
| 687 |
+
"StorageVec must be divisible by Element");
|
| 688 |
+
|
| 689 |
+
private:
|
| 690 |
+
|
| 691 |
+
///! Number of elements per storage vector
|
| 692 |
+
int const kElementsPerVector = sizeof_bits<StorageVec>::value / sizeof_bits<Element>::value;
|
| 693 |
+
|
| 694 |
+
///! Bit mask for storage unit.
|
| 695 |
+
StorageUnit const kMask = (StorageUnit(1) << sizeof_bits<Element>::value) - StorageUnit(1);
|
| 696 |
+
|
| 697 |
+
/// Pointer to array containing element
|
| 698 |
+
_war<StorageVecPointer> ptr_;
|
| 699 |
+
|
| 700 |
+
/// Offset (in units of elements) from pointer.
|
| 701 |
+
///
|
| 702 |
+
/// Invariant: must always be in range [0, kElementsPerVector)
|
| 703 |
+
int offset_;
|
| 704 |
+
|
| 705 |
+
/// Element may be stored across 2 storage unit.
|
| 706 |
+
/// Low storage unit index in StorageVec
|
| 707 |
+
/// High storage unit index in StorageVec
|
| 708 |
+
int low_storage_unit_idx_;
|
| 709 |
+
int high_storage_unit_idx_;
|
| 710 |
+
|
| 711 |
+
/// Full Mask to extract the entire element
|
| 712 |
+
uint64_t full_element_mask_;
|
| 713 |
+
|
| 714 |
+
/// Mask to extract the Element from Low storage unit and High storage unit.
|
| 715 |
+
StorageUnit low_storage_mask_;
|
| 716 |
+
StorageUnit high_storage_mask_;
|
| 717 |
+
|
| 718 |
+
/// Start bit index inside the storage unit.
|
| 719 |
+
int start_bit_idx_;
|
| 720 |
+
|
| 721 |
+
private:
|
| 722 |
+
|
| 723 |
+
CUTLASS_HOST_DEVICE
|
| 724 |
+
void update_element_status() {
|
| 725 |
+
int num_bits = offset_ * sizeof_bits<Element>::value;
|
| 726 |
+
|
| 727 |
+
start_bit_idx_ = num_bits % sizeof_bits<StorageUnit>::value;
|
| 728 |
+
|
| 729 |
+
low_storage_unit_idx_ = num_bits / sizeof_bits<StorageUnit>::value;
|
| 730 |
+
high_storage_unit_idx_ = sizeof_bits<StorageUnit>::value - (start_bit_idx_) < sizeof_bits<Element>::value
|
| 731 |
+
? low_storage_unit_idx_ + 1 : low_storage_unit_idx_;
|
| 732 |
+
|
| 733 |
+
full_element_mask_ = uint64_t(kMask) << start_bit_idx_;
|
| 734 |
+
low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0));
|
| 735 |
+
high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits<StorageUnit>::value) & ~StorageUnit(0));
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
public:
|
| 739 |
+
|
| 740 |
+
CUTLASS_HOST_DEVICE
|
| 741 |
+
SubbyteReference(): ptr_(nullptr), offset_(0) { }
|
| 742 |
+
|
| 743 |
+
/// Constructor
|
| 744 |
+
CUTLASS_HOST_DEVICE
|
| 745 |
+
SubbyteReference(
|
| 746 |
+
Element *ptr, /// pointer to memory
|
| 747 |
+
int64_t offset /// logical offset in units of Element
|
| 748 |
+
):
|
| 749 |
+
ptr_(reinterpret_cast<StorageVecPointer>(ptr)),
|
| 750 |
+
offset_(0) {
|
| 751 |
+
int64_t offset_in_vectors = offset / kElementsPerVector;
|
| 752 |
+
int64_t offset_in_elements = offset % kElementsPerVector;
|
| 753 |
+
|
| 754 |
+
ptr_ += offset_in_vectors;
|
| 755 |
+
offset_ = int(offset_in_elements);
|
| 756 |
+
|
| 757 |
+
update_element_status();
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
/// Constructor
|
| 761 |
+
CUTLASS_HOST_DEVICE
|
| 762 |
+
SubbyteReference(
|
| 763 |
+
Element *ptr = nullptr
|
| 764 |
+
): SubbyteReference(ptr, 0) { }
|
| 765 |
+
|
| 766 |
+
/// Gets StorageVec pointer
|
| 767 |
+
CUTLASS_HOST_DEVICE
|
| 768 |
+
StorageVecPointer storage_pointer() const {
|
| 769 |
+
return ptr_;
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
/// Gets StorageVec pointer
|
| 773 |
+
CUTLASS_HOST_DEVICE
|
| 774 |
+
Element * operator&() const {
|
| 775 |
+
return reinterpret_cast<Element *>(ptr_);
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
/// Gets element offset within StorageVec vector
|
| 779 |
+
CUTLASS_HOST_DEVICE
|
| 780 |
+
int element_offset() const {
|
| 781 |
+
return offset_;
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
/// Unpacks an element from memory
|
| 785 |
+
CUTLASS_HOST_DEVICE
|
| 786 |
+
Element get() const {
|
| 787 |
+
StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_;
|
| 788 |
+
StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0;
|
| 789 |
+
|
| 790 |
+
uint64_t full_item = ((uint64_t)high_bits << sizeof_bits<StorageUnit>::value) | low_bits;
|
| 791 |
+
uint8_t result = uint8_t(full_item >> start_bit_idx_);
|
| 792 |
+
|
| 793 |
+
return reinterpret_cast<Element const &>(result);
|
| 794 |
+
}
|
| 795 |
+
|
| 796 |
+
/// Stores an element to memory
|
| 797 |
+
CUTLASS_HOST_DEVICE
|
| 798 |
+
SubbyteReference & set(Element const &x) {
|
| 799 |
+
|
| 800 |
+
uint64_t item = static_cast<uint64_t>((reinterpret_cast<uint8_t const &>(x) & kMask)) << start_bit_idx_;
|
| 801 |
+
|
| 802 |
+
StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0));
|
| 803 |
+
StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits<StorageUnit>::value);
|
| 804 |
+
|
| 805 |
+
StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0)));
|
| 806 |
+
StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits<StorageUnit>::value) & (~StorageUnit(0)));
|
| 807 |
+
|
| 808 |
+
#if defined(__CUDA_ARCH__)
|
| 809 |
+
//
|
| 810 |
+
// Homebrew read-modify-write
|
| 811 |
+
//
|
| 812 |
+
if(high_storage_unit_idx_ != low_storage_unit_idx_){
|
| 813 |
+
/// Only need update 2 storage unit at once.
|
| 814 |
+
/// consider misaligned address issue, we need to do atomicCAS twice
|
| 815 |
+
StorageUnit original_low_bits, original_high_bits, update_low_bits, update_high_bits;
|
| 816 |
+
do {
|
| 817 |
+
original_low_bits = ((*ptr_)[low_storage_unit_idx_]);
|
| 818 |
+
update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits;
|
| 819 |
+
original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits);
|
| 820 |
+
} while (update_low_bits != original_low_bits);
|
| 821 |
+
do {
|
| 822 |
+
original_high_bits = ((*ptr_)[high_storage_unit_idx_]);
|
| 823 |
+
update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits;
|
| 824 |
+
original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits);
|
| 825 |
+
} while (update_high_bits != original_high_bits);
|
| 826 |
+
}
|
| 827 |
+
else {
|
| 828 |
+
/// Only need update 1 storage unit.
|
| 829 |
+
StorageUnit original, updated;
|
| 830 |
+
do {
|
| 831 |
+
original = ((*ptr_)[low_storage_unit_idx_]);
|
| 832 |
+
|
| 833 |
+
updated = (original & kLowUpdateMask) | low_new_bits;
|
| 834 |
+
|
| 835 |
+
original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated);
|
| 836 |
+
|
| 837 |
+
} while (updated != original);
|
| 838 |
+
}
|
| 839 |
+
#else
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits;
|
| 843 |
+
StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits;
|
| 844 |
+
|
| 845 |
+
(*ptr_)[low_storage_unit_idx_] = update_low_bits;
|
| 846 |
+
|
| 847 |
+
if(low_storage_unit_idx_ != high_storage_unit_idx_)
|
| 848 |
+
(*ptr_)[high_storage_unit_idx_] = update_high_bits;
|
| 849 |
+
#endif
|
| 850 |
+
|
| 851 |
+
return *this;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
////
|
| 855 |
+
|
| 856 |
+
/// Unpacks an element from memory
|
| 857 |
+
CUTLASS_HOST_DEVICE
|
| 858 |
+
operator Element() const {
|
| 859 |
+
return get();
|
| 860 |
+
}
|
| 861 |
+
|
| 862 |
+
/// Stores an element to memory
|
| 863 |
+
CUTLASS_HOST_DEVICE
|
| 864 |
+
SubbyteReference &operator=(Element const & x) {
|
| 865 |
+
return set(x);
|
| 866 |
+
}
|
| 867 |
+
|
| 868 |
+
/// Stores an element to memory
|
| 869 |
+
CUTLASS_HOST_DEVICE
|
| 870 |
+
SubbyteReference &operator=(SubbyteReference const & x) {
|
| 871 |
+
return set(x.get());
|
| 872 |
+
}
|
| 873 |
+
|
| 874 |
+
/// Stores an element to memory
|
| 875 |
+
CUTLASS_HOST_DEVICE
|
| 876 |
+
SubbyteReference &operator=(
|
| 877 |
+
ConstSubbyteReference<Element, StorageVec> const &x) {
|
| 878 |
+
return set(x.get());
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
+
/// Adds an offset in units of elements to the reference
|
| 882 |
+
CUTLASS_HOST_DEVICE
|
| 883 |
+
SubbyteReference &operator+=(int offset) {
|
| 884 |
+
|
| 885 |
+
offset += offset_;
|
| 886 |
+
|
| 887 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 888 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 889 |
+
|
| 890 |
+
ptr_ += offset_in_vectors;
|
| 891 |
+
offset_ = offset_in_elements;
|
| 892 |
+
|
| 893 |
+
update_element_status();
|
| 894 |
+
|
| 895 |
+
return *this;
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
/// Adds an offset in units of elements to the reference
|
| 899 |
+
CUTLASS_HOST_DEVICE
|
| 900 |
+
SubbyteReference &operator+=(long long offset) {
|
| 901 |
+
|
| 902 |
+
offset += offset_;
|
| 903 |
+
|
| 904 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 905 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 906 |
+
|
| 907 |
+
ptr_ += offset_in_vectors;
|
| 908 |
+
offset_ = offset_in_elements;
|
| 909 |
+
|
| 910 |
+
update_element_status();
|
| 911 |
+
|
| 912 |
+
return *this;
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
/// Adds an offset in units of elements to the reference
|
| 916 |
+
CUTLASS_HOST_DEVICE
|
| 917 |
+
SubbyteReference &operator-=(int offset) {
|
| 918 |
+
|
| 919 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 920 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 921 |
+
|
| 922 |
+
ptr_ -= offset_in_vectors;
|
| 923 |
+
offset_ -= offset_in_elements;
|
| 924 |
+
|
| 925 |
+
if (offset_ < 0) {
|
| 926 |
+
offset_ += kElementsPerVector;
|
| 927 |
+
--ptr_;
|
| 928 |
+
}
|
| 929 |
+
|
| 930 |
+
update_element_status();
|
| 931 |
+
return *this;
|
| 932 |
+
}
|
| 933 |
+
|
| 934 |
+
/// Adds an offset in units of elements to the reference
|
| 935 |
+
CUTLASS_HOST_DEVICE
|
| 936 |
+
SubbyteReference &operator-=(long long offset) {
|
| 937 |
+
|
| 938 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 939 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 940 |
+
|
| 941 |
+
ptr_ -= offset_in_vectors;
|
| 942 |
+
offset_ -= offset_in_elements;
|
| 943 |
+
|
| 944 |
+
if (offset_ < 0) {
|
| 945 |
+
offset_ += kElementsPerVector;
|
| 946 |
+
--ptr_;
|
| 947 |
+
}
|
| 948 |
+
|
| 949 |
+
update_element_status();
|
| 950 |
+
return *this;
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 954 |
+
CUTLASS_HOST_DEVICE
|
| 955 |
+
SubbyteReference operator+(int offset) const {
|
| 956 |
+
|
| 957 |
+
SubbyteReference ref(ptr_, offset_);
|
| 958 |
+
ref += offset;
|
| 959 |
+
|
| 960 |
+
return ref;
|
| 961 |
+
}
|
| 962 |
+
|
| 963 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 964 |
+
CUTLASS_HOST_DEVICE
|
| 965 |
+
SubbyteReference operator+(long long offset) const {
|
| 966 |
+
|
| 967 |
+
SubbyteReference ref(ptr_, offset_);
|
| 968 |
+
ref += offset;
|
| 969 |
+
|
| 970 |
+
return ref;
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 974 |
+
CUTLASS_HOST_DEVICE
|
| 975 |
+
SubbyteReference operator-(int offset) const {
|
| 976 |
+
|
| 977 |
+
SubbyteReference ref(ptr_, offset_);
|
| 978 |
+
ref -= offset;
|
| 979 |
+
|
| 980 |
+
return ref;
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 984 |
+
CUTLASS_HOST_DEVICE
|
| 985 |
+
SubbyteReference operator-=(long long offset) const {
|
| 986 |
+
|
| 987 |
+
SubbyteReference ref(ptr_, offset_);
|
| 988 |
+
ref -= offset;
|
| 989 |
+
|
| 990 |
+
return ref;
|
| 991 |
+
}
|
| 992 |
+
|
| 993 |
+
/// Computes the difference in elements between references
|
| 994 |
+
CUTLASS_HOST_DEVICE
|
| 995 |
+
ptrdiff_t operator-(SubbyteReference ref) const {
|
| 996 |
+
return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
|
| 997 |
+
}
|
| 998 |
+
|
| 999 |
+
/// Explicit cast to int
|
| 1000 |
+
CUTLASS_HOST_DEVICE
|
| 1001 |
+
explicit operator int() const {
|
| 1002 |
+
return int(get());
|
| 1003 |
+
}
|
| 1004 |
+
|
| 1005 |
+
/// Explicit cast to signed 64-bit integer
|
| 1006 |
+
CUTLASS_HOST_DEVICE
|
| 1007 |
+
explicit operator int64_t() const {
|
| 1008 |
+
return int64_t(get());
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
+
/// Explicit cast to unsigned 64-bit integer
|
| 1012 |
+
CUTLASS_HOST_DEVICE
|
| 1013 |
+
explicit operator uint64_t() const {
|
| 1014 |
+
return uint64_t(get());
|
| 1015 |
+
}
|
| 1016 |
+
|
| 1017 |
+
/// Explicit cast to float
|
| 1018 |
+
CUTLASS_HOST_DEVICE
|
| 1019 |
+
explicit operator float() const {
|
| 1020 |
+
return float(get());
|
| 1021 |
+
}
|
| 1022 |
+
|
| 1023 |
+
/// Explicit cast to double
|
| 1024 |
+
CUTLASS_HOST_DEVICE
|
| 1025 |
+
explicit operator double() const {
|
| 1026 |
+
return double(get());
|
| 1027 |
+
}
|
| 1028 |
+
};
|
| 1029 |
+
|
| 1030 |
+
template<typename T> using _war = T;
|
| 1031 |
+
template <
|
| 1032 |
+
typename Element_, /// CUTLASS numeric element type.
|
| 1033 |
+
typename Storage_ /// Underlying storage type. Must be able to hold an integer
|
| 1034 |
+
>
|
| 1035 |
+
class ConstSubbyteReference<Element_, Storage_,
|
| 1036 |
+
typename platform::enable_if<sizeof_bits<Storage_>::value % sizeof_bits<Element_>::value != 0>::type> {
|
| 1037 |
+
public:
|
| 1038 |
+
|
| 1039 |
+
using Element = Element_;
|
| 1040 |
+
///! Note: Storage unit could not be divisibale by Element,
|
| 1041 |
+
/// Type element may be stored across 2 storage units, so need a storage vector to hold integer
|
| 1042 |
+
/// number of objects of type Element.
|
| 1043 |
+
using StorageUnit = Storage_;
|
| 1044 |
+
static int const kBitsStoredVec = cutlass::lcm_cxx11(sizeof_bits<Element>::value, sizeof_bits<StorageUnit>::value);
|
| 1045 |
+
static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits<StorageUnit>::value;
|
| 1046 |
+
|
| 1047 |
+
using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec];
|
| 1048 |
+
using StorageVecPointer = StorageVec const *;
|
| 1049 |
+
|
| 1050 |
+
using CudaAtomicType = typename platform::conditional<
|
| 1051 |
+
sizeof_bits<StorageUnit>::value == 16,
|
| 1052 |
+
uint32_t,
|
| 1053 |
+
uint64_t
|
| 1054 |
+
>::type;
|
| 1055 |
+
|
| 1056 |
+
static_assert(sizeof_bits<Element>::value <= sizeof_bits<StorageVec>::value,
|
| 1057 |
+
"Size of Element must not be greater than StorageVec.");
|
| 1058 |
+
|
| 1059 |
+
static_assert(!(sizeof_bits<StorageVec>::value % sizeof_bits<Element>::value),
|
| 1060 |
+
"StorageVec must be divisible by Element");
|
| 1061 |
+
|
| 1062 |
+
private:
|
| 1063 |
+
|
| 1064 |
+
///! Number of elements per storage vector
|
| 1065 |
+
int const kElementsPerVector = sizeof_bits<StorageVec>::value / sizeof_bits<Element>::value;
|
| 1066 |
+
|
| 1067 |
+
///! Bit mask for storage unit.
|
| 1068 |
+
StorageUnit const kMask = (StorageUnit(1) << sizeof_bits<Element>::value) - StorageUnit(1);
|
| 1069 |
+
|
| 1070 |
+
/// Pointer to array containing element
|
| 1071 |
+
_war<StorageVecPointer> ptr_;
|
| 1072 |
+
|
| 1073 |
+
/// Offset (in units of elements) from pointer.
|
| 1074 |
+
///
|
| 1075 |
+
/// Invariant: must always be in range [0, kElementsPerVector)
|
| 1076 |
+
int offset_;
|
| 1077 |
+
|
| 1078 |
+
/// Element may be stored across 2 storage unit.
|
| 1079 |
+
/// Low storage unit index in StorageVec
|
| 1080 |
+
/// High storage unit index in StorageVec
|
| 1081 |
+
int low_storage_unit_idx_;
|
| 1082 |
+
int high_storage_unit_idx_;
|
| 1083 |
+
|
| 1084 |
+
/// Full Mask to extract the entire element
|
| 1085 |
+
uint64_t full_element_mask_;
|
| 1086 |
+
|
| 1087 |
+
/// Mask to extract the Element from Low storage unit and High storage unit.
|
| 1088 |
+
StorageUnit low_storage_mask_;
|
| 1089 |
+
StorageUnit high_storage_mask_;
|
| 1090 |
+
|
| 1091 |
+
/// Start bit index inside the storage unit.
|
| 1092 |
+
int start_bit_idx_;
|
| 1093 |
+
|
| 1094 |
+
private:
|
| 1095 |
+
|
| 1096 |
+
CUTLASS_HOST_DEVICE
|
| 1097 |
+
void update_element_status() {
|
| 1098 |
+
int num_bits = offset_ * sizeof_bits<Element>::value;
|
| 1099 |
+
|
| 1100 |
+
start_bit_idx_ = num_bits % sizeof_bits<StorageUnit>::value;
|
| 1101 |
+
|
| 1102 |
+
low_storage_unit_idx_ = num_bits / sizeof_bits<StorageUnit>::value;
|
| 1103 |
+
high_storage_unit_idx_ = sizeof_bits<StorageUnit>::value - (start_bit_idx_) < sizeof_bits<Element>::value
|
| 1104 |
+
? low_storage_unit_idx_ + 1 : low_storage_unit_idx_;
|
| 1105 |
+
|
| 1106 |
+
full_element_mask_ = uint64_t(kMask) << start_bit_idx_;
|
| 1107 |
+
low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0));
|
| 1108 |
+
high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits<StorageUnit>::value) & ~StorageUnit(0));
|
| 1109 |
+
}
|
| 1110 |
+
|
| 1111 |
+
public:
|
| 1112 |
+
|
| 1113 |
+
CUTLASS_HOST_DEVICE
|
| 1114 |
+
ConstSubbyteReference(): ptr_(nullptr), offset_(0) { }
|
| 1115 |
+
|
| 1116 |
+
/// Constructor
|
| 1117 |
+
CUTLASS_HOST_DEVICE
|
| 1118 |
+
ConstSubbyteReference(
|
| 1119 |
+
Element const *ptr, /// pointer to memory
|
| 1120 |
+
int64_t offset /// logical offset in units of Element
|
| 1121 |
+
):
|
| 1122 |
+
ptr_(reinterpret_cast<StorageVecPointer>(ptr)),
|
| 1123 |
+
offset_(0) {
|
| 1124 |
+
|
| 1125 |
+
int64_t offset_in_vectors = offset / kElementsPerVector;
|
| 1126 |
+
int64_t offset_in_elements = offset % kElementsPerVector;
|
| 1127 |
+
|
| 1128 |
+
ptr_ += offset_in_vectors;
|
| 1129 |
+
offset_ = int(offset_in_elements);
|
| 1130 |
+
|
| 1131 |
+
update_element_status();
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
/// Constructor
|
| 1135 |
+
CUTLASS_HOST_DEVICE
|
| 1136 |
+
ConstSubbyteReference(
|
| 1137 |
+
Element *ptr = nullptr
|
| 1138 |
+
): ConstSubbyteReference(ptr, 0) { }
|
| 1139 |
+
|
| 1140 |
+
/// Gets storage pointer
|
| 1141 |
+
CUTLASS_HOST_DEVICE
|
| 1142 |
+
StorageVecPointer storage_pointer() const {
|
| 1143 |
+
return ptr_;
|
| 1144 |
+
}
|
| 1145 |
+
|
| 1146 |
+
/// Gets element offset within storage vector
|
| 1147 |
+
CUTLASS_HOST_DEVICE
|
| 1148 |
+
int element_offset() const {
|
| 1149 |
+
return offset_;
|
| 1150 |
+
}
|
| 1151 |
+
|
| 1152 |
+
/// Unpacks an element from memory
|
| 1153 |
+
CUTLASS_HOST_DEVICE
|
| 1154 |
+
Element get() const {
|
| 1155 |
+
StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_;
|
| 1156 |
+
StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0;
|
| 1157 |
+
|
| 1158 |
+
uint64_t full_item = ((uint64_t)high_bits << sizeof_bits<StorageUnit>::value) | low_bits;
|
| 1159 |
+
uint8_t result = uint8_t(full_item >> start_bit_idx_);
|
| 1160 |
+
|
| 1161 |
+
return reinterpret_cast<Element const &>(result);
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
/// Unpacks an element from memory
|
| 1165 |
+
CUTLASS_HOST_DEVICE
|
| 1166 |
+
operator Element() const {
|
| 1167 |
+
return get();
|
| 1168 |
+
}
|
| 1169 |
+
|
| 1170 |
+
/// Adds an offset in units of elements to the reference
|
| 1171 |
+
CUTLASS_HOST_DEVICE
|
| 1172 |
+
ConstSubbyteReference &operator+=(int offset) {
|
| 1173 |
+
|
| 1174 |
+
offset += offset_;
|
| 1175 |
+
|
| 1176 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 1177 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 1178 |
+
|
| 1179 |
+
ptr_ += offset_in_vectors;
|
| 1180 |
+
offset_ = offset_in_elements;
|
| 1181 |
+
|
| 1182 |
+
update_element_status();
|
| 1183 |
+
|
| 1184 |
+
return *this;
|
| 1185 |
+
}
|
| 1186 |
+
|
| 1187 |
+
/// Adds an offset in units of elements to the reference
|
| 1188 |
+
CUTLASS_HOST_DEVICE
|
| 1189 |
+
ConstSubbyteReference &operator+=(long long offset) {
|
| 1190 |
+
|
| 1191 |
+
offset += offset_;
|
| 1192 |
+
|
| 1193 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 1194 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 1195 |
+
|
| 1196 |
+
ptr_ += offset_in_vectors;
|
| 1197 |
+
offset_ = offset_in_elements;
|
| 1198 |
+
|
| 1199 |
+
update_element_status();
|
| 1200 |
+
|
| 1201 |
+
return *this;
|
| 1202 |
+
}
|
| 1203 |
+
|
| 1204 |
+
/// Adds an offset in units of elements to the reference
|
| 1205 |
+
CUTLASS_HOST_DEVICE
|
| 1206 |
+
ConstSubbyteReference &operator-=(int offset) {
|
| 1207 |
+
|
| 1208 |
+
int offset_in_vectors = offset / kElementsPerVector;
|
| 1209 |
+
int offset_in_elements = offset % kElementsPerVector;
|
| 1210 |
+
|
| 1211 |
+
ptr_ -= offset_in_vectors;
|
| 1212 |
+
offset_ -= offset_in_elements;
|
| 1213 |
+
|
| 1214 |
+
if (offset_ < 0) {
|
| 1215 |
+
offset_ += kElementsPerVector;
|
| 1216 |
+
--ptr_;
|
| 1217 |
+
}
|
| 1218 |
+
|
| 1219 |
+
update_element_status();
|
| 1220 |
+
|
| 1221 |
+
return *this;
|
| 1222 |
+
}
|
| 1223 |
+
|
| 1224 |
+
/// Adds an offset in units of elements to the reference
|
| 1225 |
+
CUTLASS_HOST_DEVICE
|
| 1226 |
+
ConstSubbyteReference &operator-=(long long offset) {
|
| 1227 |
+
|
| 1228 |
+
long long offset_in_vectors = offset / kElementsPerVector;
|
| 1229 |
+
int offset_in_elements = int(offset % kElementsPerVector);
|
| 1230 |
+
|
| 1231 |
+
ptr_ -= offset_in_vectors;
|
| 1232 |
+
offset_ -= offset_in_elements;
|
| 1233 |
+
|
| 1234 |
+
if (offset_ < 0) {
|
| 1235 |
+
offset_ += kElementsPerVector;
|
| 1236 |
+
--ptr_;
|
| 1237 |
+
}
|
| 1238 |
+
|
| 1239 |
+
update_element_status();
|
| 1240 |
+
|
| 1241 |
+
return *this;
|
| 1242 |
+
}
|
| 1243 |
+
|
| 1244 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 1245 |
+
CUTLASS_HOST_DEVICE
|
| 1246 |
+
ConstSubbyteReference operator+(int offset) const {
|
| 1247 |
+
|
| 1248 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 1249 |
+
ref += offset;
|
| 1250 |
+
|
| 1251 |
+
return ref;
|
| 1252 |
+
}
|
| 1253 |
+
|
| 1254 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 1255 |
+
CUTLASS_HOST_DEVICE
|
| 1256 |
+
ConstSubbyteReference operator+(long long offset) const {
|
| 1257 |
+
|
| 1258 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 1259 |
+
ref += offset;
|
| 1260 |
+
|
| 1261 |
+
return ref;
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 1265 |
+
CUTLASS_HOST_DEVICE
|
| 1266 |
+
ConstSubbyteReference operator-(int offset) const {
|
| 1267 |
+
|
| 1268 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 1269 |
+
ref -= offset;
|
| 1270 |
+
|
| 1271 |
+
return ref;
|
| 1272 |
+
}
|
| 1273 |
+
|
| 1274 |
+
/// Returns a reference to an element with a given offset from the current reference
|
| 1275 |
+
CUTLASS_HOST_DEVICE
|
| 1276 |
+
ConstSubbyteReference operator-=(long long offset) const {
|
| 1277 |
+
|
| 1278 |
+
ConstSubbyteReference ref(ptr_, offset_);
|
| 1279 |
+
ref -= offset;
|
| 1280 |
+
|
| 1281 |
+
return ref;
|
| 1282 |
+
}
|
| 1283 |
+
|
| 1284 |
+
/// Computes the difference in elements between references
|
| 1285 |
+
CUTLASS_HOST_DEVICE
|
| 1286 |
+
ptrdiff_t operator-(ConstSubbyteReference ref) const {
|
| 1287 |
+
return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
|
| 1288 |
+
}
|
| 1289 |
+
|
| 1290 |
+
/// Explicit cast to int
|
| 1291 |
+
CUTLASS_HOST_DEVICE
|
| 1292 |
+
explicit operator int() const {
|
| 1293 |
+
return int(get());
|
| 1294 |
+
}
|
| 1295 |
+
|
| 1296 |
+
/// Explicit cast to signed 64-bit integer
|
| 1297 |
+
CUTLASS_HOST_DEVICE
|
| 1298 |
+
explicit operator int64_t() const {
|
| 1299 |
+
return int64_t(get());
|
| 1300 |
+
}
|
| 1301 |
+
|
| 1302 |
+
/// Explicit cast to unsigned 64-bit integer
|
| 1303 |
+
CUTLASS_HOST_DEVICE
|
| 1304 |
+
explicit operator uint64_t() const {
|
| 1305 |
+
return uint64_t(get());
|
| 1306 |
+
}
|
| 1307 |
+
|
| 1308 |
+
/// Explicit cast to float
|
| 1309 |
+
CUTLASS_HOST_DEVICE
|
| 1310 |
+
explicit operator float() const {
|
| 1311 |
+
return float(get());
|
| 1312 |
+
}
|
| 1313 |
+
|
| 1314 |
+
/// Explicit cast to double
|
| 1315 |
+
CUTLASS_HOST_DEVICE
|
| 1316 |
+
explicit operator double() const {
|
| 1317 |
+
return double(get());
|
| 1318 |
+
}
|
| 1319 |
+
};
|
| 1320 |
+
|
| 1321 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1322 |
+
|
| 1323 |
+
template <typename Element, bool subbyte = (sizeof_bits<Element>::value < 8)>
|
| 1324 |
+
struct ReferenceFactory;
|
| 1325 |
+
|
| 1326 |
+
template <typename Element>
|
| 1327 |
+
struct ReferenceFactory<Element, false> {
|
| 1328 |
+
|
| 1329 |
+
///! Number of elements per storage vector
|
| 1330 |
+
static int const kElementsPerVector = 1;
|
| 1331 |
+
|
| 1332 |
+
CUTLASS_HOST_DEVICE
|
| 1333 |
+
static Element &get(Element *ptr, int64_t offset) {
|
| 1334 |
+
return ptr[offset];
|
| 1335 |
+
}
|
| 1336 |
+
|
| 1337 |
+
CUTLASS_HOST_DEVICE
|
| 1338 |
+
static Element const &get(Element const *ptr, int64_t offset) {
|
| 1339 |
+
return ptr[offset];
|
| 1340 |
+
}
|
| 1341 |
+
|
| 1342 |
+
CUTLASS_HOST_DEVICE
|
| 1343 |
+
static Element *add_pointer_offset(Element *ptr, int64_t offset) {
|
| 1344 |
+
return ptr + offset;
|
| 1345 |
+
}
|
| 1346 |
+
|
| 1347 |
+
CUTLASS_HOST_DEVICE
|
| 1348 |
+
static Element const *add_pointer_offset(Element const *ptr, int64_t offset) {
|
| 1349 |
+
return ptr + offset;
|
| 1350 |
+
}
|
| 1351 |
+
};
|
| 1352 |
+
|
| 1353 |
+
template <typename Element>
|
| 1354 |
+
struct ReferenceFactory<Element, true> {
|
| 1355 |
+
|
| 1356 |
+
//
|
| 1357 |
+
// Static methods
|
| 1358 |
+
//
|
| 1359 |
+
|
| 1360 |
+
CUTLASS_HOST_DEVICE
|
| 1361 |
+
static SubbyteReference<Element> get(Element *ptr, int64_t offset) {
|
| 1362 |
+
return SubbyteReference<Element>(ptr, offset);
|
| 1363 |
+
}
|
| 1364 |
+
|
| 1365 |
+
CUTLASS_HOST_DEVICE
|
| 1366 |
+
static ConstSubbyteReference<Element> get(Element const *ptr,
|
| 1367 |
+
int64_t offset) {
|
| 1368 |
+
return ConstSubbyteReference<Element>(ptr, offset);
|
| 1369 |
+
}
|
| 1370 |
+
|
| 1371 |
+
/// Helper to add an offset in number of elements, assuming this offset is divisible
|
| 1372 |
+
/// by the vector size.
|
| 1373 |
+
CUTLASS_HOST_DEVICE
|
| 1374 |
+
static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) {
|
| 1375 |
+
return &SubbyteReference<Element>(ptr, offset_in_elements);
|
| 1376 |
+
}
|
| 1377 |
+
|
| 1378 |
+
/// Helper to add an offset in number of elements, assuming this offset is divisible
|
| 1379 |
+
/// by the vector size.
|
| 1380 |
+
CUTLASS_HOST_DEVICE
|
| 1381 |
+
static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) {
|
| 1382 |
+
return &ConstSubbyteReference<Element>(ptr, offset_in_elements);
|
| 1383 |
+
}
|
| 1384 |
+
};
|
| 1385 |
+
|
| 1386 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1387 |
+
|
| 1388 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a canonical coordinate for rank=4 tensors offering named indices.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "cutlass/coord.h"
|
| 38 |
+
|
| 39 |
+
namespace cutlass {
|
| 40 |
+
|
| 41 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
/// Defines a canonical 4D coordinate used by tensor operations.
|
| 44 |
+
struct Tensor4DCoord : public Coord<4> {
|
| 45 |
+
|
| 46 |
+
/// Base class
|
| 47 |
+
using Base = Coord<4>;
|
| 48 |
+
|
| 49 |
+
/// Index type
|
| 50 |
+
using Index = typename Base::Index;
|
| 51 |
+
|
| 52 |
+
/// LongIndex type
|
| 53 |
+
using LongIndex = typename Base::LongIndex;
|
| 54 |
+
|
| 55 |
+
/// Batch dimension
|
| 56 |
+
static int const kN = 0;
|
| 57 |
+
|
| 58 |
+
/// Height dimension
|
| 59 |
+
static int const kH = 1;
|
| 60 |
+
|
| 61 |
+
/// Width dimension
|
| 62 |
+
static int const kW = 2;
|
| 63 |
+
|
| 64 |
+
/// Channels dimension
|
| 65 |
+
static int const kC = 3;
|
| 66 |
+
|
| 67 |
+
//
|
| 68 |
+
// Methods
|
| 69 |
+
//
|
| 70 |
+
|
| 71 |
+
/// Default ctor
|
| 72 |
+
CUTLASS_HOST_DEVICE
|
| 73 |
+
Tensor4DCoord() { }
|
| 74 |
+
|
| 75 |
+
/// Constructs from Coord<4>
|
| 76 |
+
CUTLASS_HOST_DEVICE
|
| 77 |
+
Tensor4DCoord(Coord<4> const &coord): Base(coord) { }
|
| 78 |
+
|
| 79 |
+
/// Helper to construct from N, H, W, and C.
|
| 80 |
+
CUTLASS_HOST_DEVICE
|
| 81 |
+
Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { }
|
| 82 |
+
|
| 83 |
+
/// Helper to construct from N, H, W, and C, which are LongIndex type
|
| 84 |
+
CUTLASS_HOST_DEVICE
|
| 85 |
+
Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c)
|
| 86 |
+
: Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { }
|
| 87 |
+
|
| 88 |
+
/// Returns the batch of the coordinate
|
| 89 |
+
CUTLASS_HOST_DEVICE
|
| 90 |
+
Index const & n() const { return this->at(kN); }
|
| 91 |
+
|
| 92 |
+
/// Returns the batch of the coordinate
|
| 93 |
+
CUTLASS_HOST_DEVICE
|
| 94 |
+
Index & n() { return this->at(kN); }
|
| 95 |
+
|
| 96 |
+
/// Returns the row of the coordinate
|
| 97 |
+
CUTLASS_HOST_DEVICE
|
| 98 |
+
Index const & h() const { return this->at(kH); }
|
| 99 |
+
|
| 100 |
+
/// Returns the row of the coordinate
|
| 101 |
+
CUTLASS_HOST_DEVICE
|
| 102 |
+
Index & h() { return this->at(kH); }
|
| 103 |
+
|
| 104 |
+
/// Returns the column of the coordinate
|
| 105 |
+
CUTLASS_HOST_DEVICE
|
| 106 |
+
Index const & w() const { return this->at(kW); }
|
| 107 |
+
|
| 108 |
+
/// Returns the column of the coordinate
|
| 109 |
+
CUTLASS_HOST_DEVICE
|
| 110 |
+
Index & w() { return this->at(kW); }
|
| 111 |
+
|
| 112 |
+
/// Returns the channel of the coordinate
|
| 113 |
+
CUTLASS_HOST_DEVICE
|
| 114 |
+
Index const & c() const { return this->at(kC); }
|
| 115 |
+
|
| 116 |
+
/// Returns the channel of the coordinate
|
| 117 |
+
CUTLASS_HOST_DEVICE
|
| 118 |
+
Index & c() { return this->at(kC); }
|
| 119 |
+
|
| 120 |
+
//
|
| 121 |
+
// Coord operators
|
| 122 |
+
//
|
| 123 |
+
|
| 124 |
+
/// Element-wise addition
|
| 125 |
+
CUTLASS_HOST_DEVICE
|
| 126 |
+
Tensor4DCoord operator+(Base const& b) const {
|
| 127 |
+
return Tensor4DCoord(Base::operator+(b));
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Element-wise subtraction
|
| 131 |
+
CUTLASS_HOST_DEVICE
|
| 132 |
+
Tensor4DCoord operator-(Base const& b) const {
|
| 133 |
+
return Tensor4DCoord(Base::operator-(b));
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
/// Element-wise multiplication
|
| 137 |
+
CUTLASS_HOST_DEVICE
|
| 138 |
+
Tensor4DCoord operator*(Base const& b) const {
|
| 139 |
+
return Tensor4DCoord(Base::operator*(b));
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/// Element-wise division
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
Tensor4DCoord operator/(Base const& b) const {
|
| 145 |
+
return Tensor4DCoord(Base::operator/(b));
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// In-place addition
|
| 149 |
+
CUTLASS_HOST_DEVICE
|
| 150 |
+
Tensor4DCoord& operator+=(Base const& b) {
|
| 151 |
+
Base::operator+=(b);
|
| 152 |
+
return *this;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
/// In-place subtraction
|
| 156 |
+
CUTLASS_HOST_DEVICE
|
| 157 |
+
Tensor4DCoord& operator-=(Base const& b) {
|
| 158 |
+
Base::operator-=(b);
|
| 159 |
+
return *this;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
/// In-place multiplication
|
| 163 |
+
CUTLASS_HOST_DEVICE
|
| 164 |
+
Tensor4DCoord& operator*=(Base const& b) {
|
| 165 |
+
Base::operator*=(b);
|
| 166 |
+
return *this;
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/// In-place division
|
| 170 |
+
CUTLASS_HOST_DEVICE
|
| 171 |
+
Tensor4DCoord& operator/=(Base const& b) {
|
| 172 |
+
Base::operator/=(b);
|
| 173 |
+
return *this;
|
| 174 |
+
}
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 178 |
+
|
| 179 |
+
/// Defines a canonical 5D coordinate used by tensor operations.
|
| 180 |
+
struct Tensor5DCoord : public Coord<5> {
|
| 181 |
+
|
| 182 |
+
/// Base class
|
| 183 |
+
using Base = Coord<5>;
|
| 184 |
+
|
| 185 |
+
/// Index type
|
| 186 |
+
using Index = typename Base::Index;
|
| 187 |
+
|
| 188 |
+
/// LongIndex type
|
| 189 |
+
using LongIndex = typename Base::LongIndex;
|
| 190 |
+
|
| 191 |
+
/// Batch dimension
|
| 192 |
+
static int const kN = 0;
|
| 193 |
+
|
| 194 |
+
/// Depth dimension
|
| 195 |
+
static int const kD = 1;
|
| 196 |
+
|
| 197 |
+
/// Height dimension
|
| 198 |
+
static int const kH = 2;
|
| 199 |
+
|
| 200 |
+
/// Width dimension
|
| 201 |
+
static int const kW = 3;
|
| 202 |
+
|
| 203 |
+
/// Channels dimension
|
| 204 |
+
static int const kC = 4;
|
| 205 |
+
|
| 206 |
+
//
|
| 207 |
+
// Methods
|
| 208 |
+
//
|
| 209 |
+
|
| 210 |
+
/// Default ctor
|
| 211 |
+
CUTLASS_HOST_DEVICE
|
| 212 |
+
Tensor5DCoord() { }
|
| 213 |
+
|
| 214 |
+
/// Constructs from Coord<5>
|
| 215 |
+
CUTLASS_HOST_DEVICE
|
| 216 |
+
Tensor5DCoord(Coord<5> const &coord): Base(coord) { }
|
| 217 |
+
|
| 218 |
+
/// Helper to construct from N, D, H, W, and C.
|
| 219 |
+
CUTLASS_HOST_DEVICE
|
| 220 |
+
Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { }
|
| 221 |
+
|
| 222 |
+
/// Helper to construct from N, D, H, W, and C, which are LongIndex type
|
| 223 |
+
CUTLASS_HOST_DEVICE
|
| 224 |
+
Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c)
|
| 225 |
+
: Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { }
|
| 226 |
+
|
| 227 |
+
/// Returns the batch of the coordinate
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
Index const & n() const { return this->at(kN); }
|
| 230 |
+
|
| 231 |
+
/// Returns the batch of the coordinate
|
| 232 |
+
CUTLASS_HOST_DEVICE
|
| 233 |
+
Index & n() { return this->at(kN); }
|
| 234 |
+
|
| 235 |
+
/// Returns the batch of the coordinate
|
| 236 |
+
CUTLASS_HOST_DEVICE
|
| 237 |
+
Index const & d() const { return this->at(kD); }
|
| 238 |
+
|
| 239 |
+
/// Returns the batch of the coordinate
|
| 240 |
+
CUTLASS_HOST_DEVICE
|
| 241 |
+
Index & d() { return this->at(kD); }
|
| 242 |
+
|
| 243 |
+
/// Returns the row of the coordinate
|
| 244 |
+
CUTLASS_HOST_DEVICE
|
| 245 |
+
Index const & h() const { return this->at(kH); }
|
| 246 |
+
|
| 247 |
+
/// Returns the row of the coordinate
|
| 248 |
+
CUTLASS_HOST_DEVICE
|
| 249 |
+
Index & h() { return this->at(kH); }
|
| 250 |
+
|
| 251 |
+
/// Returns the column of the coordinate
|
| 252 |
+
CUTLASS_HOST_DEVICE
|
| 253 |
+
Index const & w() const { return this->at(kW); }
|
| 254 |
+
|
| 255 |
+
/// Returns the column of the coordinate
|
| 256 |
+
CUTLASS_HOST_DEVICE
|
| 257 |
+
Index & w() { return this->at(kW); }
|
| 258 |
+
|
| 259 |
+
/// Returns the channel of the coordinate
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
Index const & c() const { return this->at(kC); }
|
| 262 |
+
|
| 263 |
+
/// Returns the channel of the coordinate
|
| 264 |
+
CUTLASS_HOST_DEVICE
|
| 265 |
+
Index & c() { return this->at(kC); }
|
| 266 |
+
|
| 267 |
+
//
|
| 268 |
+
// Coord operators
|
| 269 |
+
//
|
| 270 |
+
|
| 271 |
+
/// Element-wise addition
|
| 272 |
+
CUTLASS_HOST_DEVICE
|
| 273 |
+
Tensor5DCoord operator+(Base const& b) const {
|
| 274 |
+
return Tensor5DCoord(Base::operator+(b));
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
/// Element-wise subtraction
|
| 278 |
+
CUTLASS_HOST_DEVICE
|
| 279 |
+
Tensor5DCoord operator-(Base const& b) const {
|
| 280 |
+
return Tensor5DCoord(Base::operator-(b));
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
/// Element-wise multiplication
|
| 284 |
+
CUTLASS_HOST_DEVICE
|
| 285 |
+
Tensor5DCoord operator*(Base const& b) const {
|
| 286 |
+
return Tensor5DCoord(Base::operator*(b));
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
/// Element-wise division
|
| 290 |
+
CUTLASS_HOST_DEVICE
|
| 291 |
+
Tensor5DCoord operator/(Base const& b) const {
|
| 292 |
+
return Tensor5DCoord(Base::operator/(b));
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
/// In-place addition
|
| 296 |
+
CUTLASS_HOST_DEVICE
|
| 297 |
+
Tensor5DCoord& operator+=(Base const& b) {
|
| 298 |
+
Base::operator+=(b);
|
| 299 |
+
return *this;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// In-place subtraction
|
| 303 |
+
CUTLASS_HOST_DEVICE
|
| 304 |
+
Tensor5DCoord& operator-=(Base const& b) {
|
| 305 |
+
Base::operator-=(b);
|
| 306 |
+
return *this;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
/// In-place multiplication
|
| 310 |
+
CUTLASS_HOST_DEVICE
|
| 311 |
+
Tensor5DCoord& operator*=(Base const& b) {
|
| 312 |
+
Base::operator*=(b);
|
| 313 |
+
return *this;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
/// In-place division
|
| 317 |
+
CUTLASS_HOST_DEVICE
|
| 318 |
+
Tensor5DCoord& operator/=(Base const& b) {
|
| 319 |
+
Base::operator/=(b);
|
| 320 |
+
return *this;
|
| 321 |
+
}
|
| 322 |
+
};
|
| 323 |
+
|
| 324 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 325 |
+
|
| 326 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/coord.h"
|
| 39 |
+
#include "cutlass/platform/platform.h"
|
| 40 |
+
#include "cutlass/subbyte_reference.h"
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
|
| 44 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Default layout function from coordinates in a tensor's index space into the n-D array held
|
| 47 |
+
/// in memory.
|
| 48 |
+
///
|
| 49 |
+
/// All layout functions must define at least the members shown in IdentityTensorLayout<>.
|
| 50 |
+
template <int Rank>
|
| 51 |
+
class IdentityTensorLayout {
|
| 52 |
+
public:
|
| 53 |
+
/// Logical rank of tensor
|
| 54 |
+
static int const kRank = Rank;
|
| 55 |
+
|
| 56 |
+
/// Rank of stride vector
|
| 57 |
+
static int const kStrideRank = Rank;
|
| 58 |
+
|
| 59 |
+
/// Index type used for coordinates
|
| 60 |
+
using Index = int32_t;
|
| 61 |
+
|
| 62 |
+
/// Long index type used for offsets
|
| 63 |
+
using LongIndex = int64_t;
|
| 64 |
+
|
| 65 |
+
/// Logical coordinate
|
| 66 |
+
using TensorCoord = Coord<kRank, Index>;
|
| 67 |
+
|
| 68 |
+
/// Stride vector
|
| 69 |
+
using Stride = Coord<kStrideRank, Index>;
|
| 70 |
+
|
| 71 |
+
private:
|
| 72 |
+
|
| 73 |
+
//
|
| 74 |
+
// Data members
|
| 75 |
+
//
|
| 76 |
+
|
| 77 |
+
/// Stride data member
|
| 78 |
+
Stride stride_;
|
| 79 |
+
|
| 80 |
+
public:
|
| 81 |
+
|
| 82 |
+
//
|
| 83 |
+
// Methods
|
| 84 |
+
//
|
| 85 |
+
|
| 86 |
+
CUTLASS_HOST_DEVICE
|
| 87 |
+
IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { }
|
| 88 |
+
|
| 89 |
+
/// Returns the offset of a coordinate in linear memory
|
| 90 |
+
CUTLASS_HOST_DEVICE
|
| 91 |
+
LongIndex operator()(Coord<Rank> const &coord) const {
|
| 92 |
+
return coord.dot(stride_);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
/// Returns the stride of the layout
|
| 96 |
+
CUTLASS_HOST_DEVICE
|
| 97 |
+
Stride stride() const {
|
| 98 |
+
return stride_;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
/// Returns the stride of the layout
|
| 102 |
+
CUTLASS_HOST_DEVICE
|
| 103 |
+
Stride & stride() {
|
| 104 |
+
return stride_;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
/// Compute the number of contiguous elements needed to store a tensor with the given size
|
| 108 |
+
CUTLASS_HOST_DEVICE
|
| 109 |
+
LongIndex capacity(TensorCoord const &size) const {
|
| 110 |
+
int idx = stride_.max_dim_index();
|
| 111 |
+
return stride_[idx] * size[idx];
|
| 112 |
+
}
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 116 |
+
|
| 117 |
+
/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
|
| 118 |
+
and layout within memory. A TensorRef combines a pointer and a Layout concept
|
| 119 |
+
|
| 120 |
+
Examples:
|
| 121 |
+
|
| 122 |
+
(These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h)
|
| 123 |
+
|
| 124 |
+
1. Column-major matrix may be represented as a rank=2 tensor:
|
| 125 |
+
|
| 126 |
+
TensorRef<float, layout::ColumnMajor> A(ptr_A, ldm);
|
| 127 |
+
|
| 128 |
+
2. Row-major matrix may be represented as a rank=2 tensor:
|
| 129 |
+
|
| 130 |
+
TensorRef<float, layout::RowMajor> B(ptr_A, ldm);
|
| 131 |
+
|
| 132 |
+
3. An interleaved matrix may be represented as a rank=2 tensor:
|
| 133 |
+
|
| 134 |
+
TensorRef<int8_t, layout::ColumnMajorInterleaved<32> > C;
|
| 135 |
+
|
| 136 |
+
4. A helper exists to define a TensorRef for a contiguous matrix whose layout
|
| 137 |
+
is not known at compile time.
|
| 138 |
+
|
| 139 |
+
int ldm; // leading dimension
|
| 140 |
+
layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
TensorRef<int, layout::ContiguousMatrix> E(ptr_E, {ldm, kind});
|
| 144 |
+
|
| 145 |
+
*/
|
| 146 |
+
template <
|
| 147 |
+
/// Data type of element stored within tensor (concept: NumericType)
|
| 148 |
+
typename Element_,
|
| 149 |
+
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 150 |
+
typename Layout_
|
| 151 |
+
>
|
| 152 |
+
class TensorRef {
|
| 153 |
+
public:
|
| 154 |
+
/// Data type of individual access
|
| 155 |
+
using Element = Element_;
|
| 156 |
+
|
| 157 |
+
/// Mapping function from logical coordinate to linear memory
|
| 158 |
+
using Layout = Layout_;
|
| 159 |
+
|
| 160 |
+
/// Reference type to an element
|
| 161 |
+
using Reference = typename platform::conditional<
|
| 162 |
+
sizeof_bits<Element>::value >= 8,
|
| 163 |
+
Element &,
|
| 164 |
+
SubbyteReference<Element>
|
| 165 |
+
>::type;
|
| 166 |
+
|
| 167 |
+
/// Logical rank of tensor index space
|
| 168 |
+
static int const kRank = Layout::kRank;
|
| 169 |
+
|
| 170 |
+
/// Index type
|
| 171 |
+
using Index = typename Layout::Index;
|
| 172 |
+
|
| 173 |
+
/// Long index used for pointer offsets
|
| 174 |
+
using LongIndex = typename Layout::LongIndex;
|
| 175 |
+
|
| 176 |
+
/// Coordinate in logical tensor space
|
| 177 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 178 |
+
|
| 179 |
+
/// Layout's stride vector
|
| 180 |
+
using Stride = typename Layout::Stride;
|
| 181 |
+
|
| 182 |
+
/// TensorRef to constant data
|
| 183 |
+
using ConstTensorRef = TensorRef<
|
| 184 |
+
typename platform::remove_const<Element>::type const,
|
| 185 |
+
Layout>;
|
| 186 |
+
|
| 187 |
+
/// TensorRef to non-constant data
|
| 188 |
+
using NonConstTensorRef = TensorRef<
|
| 189 |
+
typename platform::remove_const<Element>::type,
|
| 190 |
+
Layout>;
|
| 191 |
+
|
| 192 |
+
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
| 193 |
+
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
| 194 |
+
/// extensive C++ metaprogramming or support for zero-length arrays.
|
| 195 |
+
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
| 196 |
+
|
| 197 |
+
private:
|
| 198 |
+
|
| 199 |
+
/// Pointer
|
| 200 |
+
Element* ptr_;
|
| 201 |
+
|
| 202 |
+
/// Layout object maps logical coordinates to linear offsets
|
| 203 |
+
Layout layout_;
|
| 204 |
+
|
| 205 |
+
public:
|
| 206 |
+
|
| 207 |
+
//
|
| 208 |
+
// Methods
|
| 209 |
+
//
|
| 210 |
+
|
| 211 |
+
/// Constructs a TensorRef with a pointer and layout object.
|
| 212 |
+
CUTLASS_HOST_DEVICE
|
| 213 |
+
TensorRef(): ptr_(nullptr) {
|
| 214 |
+
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Constructs a TensorRef with a pointer and layout object.
|
| 218 |
+
CUTLASS_HOST_DEVICE
|
| 219 |
+
TensorRef(
|
| 220 |
+
Element *ptr, ///< pointer to start of tensor
|
| 221 |
+
Layout const &layout ///< layout object containing stride and mapping function
|
| 222 |
+
):
|
| 223 |
+
ptr_(ptr), layout_(layout) {
|
| 224 |
+
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
/// Converting constructor from TensorRef to non-constant data.
|
| 228 |
+
template<typename _Magic = int>
|
| 229 |
+
CUTLASS_HOST_DEVICE
|
| 230 |
+
TensorRef(
|
| 231 |
+
NonConstTensorRef const &ref, ///< TensorRef to non-const data
|
| 232 |
+
///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
|
| 233 |
+
_Magic magic = (typename platform::enable_if< ! platform::is_same<NonConstTensorRef, TensorRef<Element_, Layout_> >::value, _Magic>::type)0
|
| 234 |
+
):
|
| 235 |
+
ptr_(ref.data()), layout_(ref.layout()) { }
|
| 236 |
+
|
| 237 |
+
/// Returns a reference to constant-valued tensor.
|
| 238 |
+
CUTLASS_HOST_DEVICE
|
| 239 |
+
ConstTensorRef const_ref() const {
|
| 240 |
+
return ConstTensorRef(ptr_, layout_);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
CUTLASS_HOST_DEVICE
|
| 244 |
+
NonConstTensorRef non_const_ref() const {
|
| 245 |
+
return NonConstTensorRef(const_cast<typename platform::remove_const<Element>::type *>(ptr_), layout_);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/// Updates only the pointer
|
| 249 |
+
CUTLASS_HOST_DEVICE
|
| 250 |
+
void reset(Element* ptr = nullptr) {
|
| 251 |
+
ptr_ = ptr;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
/// Updates the pointer and layout object
|
| 255 |
+
CUTLASS_HOST_DEVICE
|
| 256 |
+
void reset(Element* ptr, Layout const &layout) {
|
| 257 |
+
ptr_ = ptr;
|
| 258 |
+
layout_ = layout;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/// Returns true if the TensorRef is non-null
|
| 262 |
+
CUTLASS_HOST_DEVICE
|
| 263 |
+
bool good() const {
|
| 264 |
+
return ptr_ != nullptr;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
/// Returns the pointer to referenced data
|
| 268 |
+
CUTLASS_HOST_DEVICE
|
| 269 |
+
Element * data() const { return ptr_; }
|
| 270 |
+
|
| 271 |
+
/// Returns a reference to the element at a given linear index
|
| 272 |
+
CUTLASS_HOST_DEVICE
|
| 273 |
+
Reference data(LongIndex idx) const {
|
| 274 |
+
return ReferenceFactory<typename platform::remove_const<Element>::type,
|
| 275 |
+
(sizeof_bits<Element>::value < 8)>::get(ptr_, idx);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
/// Returns the layout object
|
| 279 |
+
CUTLASS_HOST_DEVICE
|
| 280 |
+
Layout & layout() {
|
| 281 |
+
return layout_;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
/// Returns the layout object
|
| 285 |
+
CUTLASS_HOST_DEVICE
|
| 286 |
+
Layout layout() const {
|
| 287 |
+
return layout_;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
/// Returns the layout object's stride vector
|
| 291 |
+
CUTLASS_HOST_DEVICE
|
| 292 |
+
Stride stride() const {
|
| 293 |
+
return layout_.stride();
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
/// Returns the layout object's stride vector
|
| 297 |
+
CUTLASS_HOST_DEVICE
|
| 298 |
+
Stride & stride() {
|
| 299 |
+
return layout_.stride();
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 303 |
+
CUTLASS_HOST_DEVICE
|
| 304 |
+
typename Layout::Stride::Index stride(int dim) const {
|
| 305 |
+
return layout_.stride().at(dim);
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 309 |
+
CUTLASS_HOST_DEVICE
|
| 310 |
+
typename Layout::Stride::Index & stride(int dim) {
|
| 311 |
+
return layout_.stride().at(dim);
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
/// Computes the offset of an index from the origin of the tensor
|
| 315 |
+
CUTLASS_HOST_DEVICE
|
| 316 |
+
LongIndex offset(TensorCoord const& coord) const {
|
| 317 |
+
return layout_(coord);
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
/// Returns a reference to the element at a given Coord
|
| 321 |
+
CUTLASS_HOST_DEVICE
|
| 322 |
+
Reference at(TensorCoord const& coord) const {
|
| 323 |
+
return data(offset(coord));
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
/// Returns a reference to the element at a given Coord
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
Reference operator[](TensorCoord const& coord) const {
|
| 329 |
+
return data(offset(coord));
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Adds an offset to each pointer
|
| 333 |
+
CUTLASS_HOST_DEVICE
|
| 334 |
+
TensorRef & add_pointer_offset(LongIndex offset_) {
|
| 335 |
+
ptr_ = ReferenceFactory<typename platform::remove_const<Element>::type,
|
| 336 |
+
(sizeof_bits<Element>::value < 8)>::add_pointer_offset(ptr_, offset_);
|
| 337 |
+
return *this;
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
/// Adds an offset to each pointer
|
| 341 |
+
CUTLASS_HOST_DEVICE
|
| 342 |
+
TensorRef & add_coord_offset(TensorCoord const &coord) {
|
| 343 |
+
add_pointer_offset(offset(coord));
|
| 344 |
+
return *this;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/// Returns a TensorRef offset by a given amount
|
| 348 |
+
CUTLASS_HOST_DEVICE
|
| 349 |
+
TensorRef operator+(TensorCoord const& b) const {
|
| 350 |
+
TensorRef result(*this);
|
| 351 |
+
result.add_coord_offset(b);
|
| 352 |
+
return result;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
/// Returns a TensorRef offset by a given amount
|
| 356 |
+
CUTLASS_HOST_DEVICE
|
| 357 |
+
TensorRef & operator+=(TensorCoord const& b) {
|
| 358 |
+
add_coord_offset(b);
|
| 359 |
+
return *this;
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// Returns a TensorRef offset by a given amount
|
| 363 |
+
CUTLASS_HOST_DEVICE
|
| 364 |
+
TensorRef operator-(TensorCoord const& b) const {
|
| 365 |
+
TensorRef result(*this);
|
| 366 |
+
result.add_pointer_offset(-offset(b));
|
| 367 |
+
return result;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
/// Returns a TensorRef offset by a given amount
|
| 371 |
+
CUTLASS_HOST_DEVICE
|
| 372 |
+
TensorRef & operator-=(TensorCoord const& b) {
|
| 373 |
+
add_pointer_offset(-offset(b));
|
| 374 |
+
return *this;
|
| 375 |
+
}
|
| 376 |
+
};
|
| 377 |
+
|
| 378 |
+
/// Constructs a TensorRef, deducing types from arguments.
|
| 379 |
+
template <
|
| 380 |
+
typename Element,
|
| 381 |
+
typename Layout
|
| 382 |
+
>
|
| 383 |
+
CUTLASS_HOST_DEVICE
|
| 384 |
+
TensorRef<Element, Layout> make_TensorRef(Element *ptr, Layout const &layout) {
|
| 385 |
+
return TensorRef<Element, Layout>(ptr, layout);
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 389 |
+
//
|
| 390 |
+
// Partial specializations to handle degenerate and sub-byte cases.
|
| 391 |
+
//
|
| 392 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 393 |
+
|
| 394 |
+
template <
|
| 395 |
+
typename Element,
|
| 396 |
+
typename Layout
|
| 397 |
+
>
|
| 398 |
+
CUTLASS_HOST_DEVICE
|
| 399 |
+
bool TensorRef_aligned(TensorRef<Element, Layout> const &ref, int alignment) {
|
| 400 |
+
|
| 401 |
+
int const kStrideRank = Layout::kStrideRank;
|
| 402 |
+
|
| 403 |
+
if (reinterpret_cast<uintptr_t>(ref.data()) % alignment) {
|
| 404 |
+
return false;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
CUTLASS_PRAGMA_UNROLL
|
| 408 |
+
for (int i = 0; i < kStrideRank; ++i) {
|
| 409 |
+
if (ref.stride(i) % alignment) {
|
| 410 |
+
return false;
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
return true;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 418 |
+
|
| 419 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a structure containing strides, bounds, and a pointer to tensor data.
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <cstdint>
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/complex.h"
|
| 39 |
+
#include "cutlass/tensor_ref.h"
|
| 40 |
+
|
| 41 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 42 |
+
|
| 43 |
+
namespace cutlass {
|
| 44 |
+
|
| 45 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
template <typename Element_>
|
| 48 |
+
struct PlanarComplexReference {
|
| 49 |
+
|
| 50 |
+
//
|
| 51 |
+
// Type definitions
|
| 52 |
+
//
|
| 53 |
+
|
| 54 |
+
using Element = Element_;
|
| 55 |
+
using ComplexElement = complex<Element>;
|
| 56 |
+
|
| 57 |
+
//
|
| 58 |
+
// Data members
|
| 59 |
+
//
|
| 60 |
+
|
| 61 |
+
Element *real;
|
| 62 |
+
Element *imag;
|
| 63 |
+
|
| 64 |
+
//
|
| 65 |
+
// Methods
|
| 66 |
+
//
|
| 67 |
+
|
| 68 |
+
CUTLASS_HOST_DEVICE
|
| 69 |
+
PlanarComplexReference(
|
| 70 |
+
Element *real_ = nullptr,
|
| 71 |
+
Element *imag_ = nullptr
|
| 72 |
+
):
|
| 73 |
+
real(real_), imag(imag_) { }
|
| 74 |
+
|
| 75 |
+
/// Loads the complex element
|
| 76 |
+
CUTLASS_HOST_DEVICE
|
| 77 |
+
operator complex<Element>() const {
|
| 78 |
+
return complex<Element>{*real, *imag};
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
/// Stores a complex element to the location pointed to by the reference
|
| 82 |
+
CUTLASS_HOST_DEVICE
|
| 83 |
+
PlanarComplexReference &operator=(complex<Element> const &rhs) {
|
| 84 |
+
*real = rhs.real();
|
| 85 |
+
*imag = rhs.imag();
|
| 86 |
+
return *this;
|
| 87 |
+
}
|
| 88 |
+
};
|
| 89 |
+
|
| 90 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 91 |
+
|
| 92 |
+
/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
|
| 93 |
+
and layout within memory. A TensorRef combines a pointer and a Layout concept
|
| 94 |
+
|
| 95 |
+
*/
|
| 96 |
+
template <
|
| 97 |
+
/// Data type of element stored within tensor (concept: NumericType)
|
| 98 |
+
typename Element_,
|
| 99 |
+
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
| 100 |
+
typename Layout_
|
| 101 |
+
>
|
| 102 |
+
class TensorRefPlanarComplex {
|
| 103 |
+
public:
|
| 104 |
+
/// Data type of individual access
|
| 105 |
+
using Element = Element_;
|
| 106 |
+
|
| 107 |
+
/// Complex element type
|
| 108 |
+
using ComplexElement = complex<Element>;
|
| 109 |
+
|
| 110 |
+
/// Mapping function from logical coordinate to linear memory
|
| 111 |
+
using Layout = Layout_;
|
| 112 |
+
|
| 113 |
+
static_assert(sizeof_bits<Element>::value >= 8,
|
| 114 |
+
"Planar complex not suitable for subbyte elements at this time");
|
| 115 |
+
|
| 116 |
+
/// Reference type to an element
|
| 117 |
+
using Reference = PlanarComplexReference<Element>;
|
| 118 |
+
|
| 119 |
+
/// Logical rank of tensor index space
|
| 120 |
+
static int const kRank = Layout::kRank;
|
| 121 |
+
|
| 122 |
+
/// Index type
|
| 123 |
+
using Index = typename Layout::Index;
|
| 124 |
+
|
| 125 |
+
/// Long index used for pointer offsets
|
| 126 |
+
using LongIndex = typename Layout::LongIndex;
|
| 127 |
+
|
| 128 |
+
/// Coordinate in logical tensor space
|
| 129 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 130 |
+
|
| 131 |
+
/// Layout's stride vector
|
| 132 |
+
using Stride = typename Layout::Stride;
|
| 133 |
+
|
| 134 |
+
/// TensorRef to constant data
|
| 135 |
+
using ConstTensorRef = TensorRefPlanarComplex<
|
| 136 |
+
typename platform::remove_const<Element>::type const,
|
| 137 |
+
Layout>;
|
| 138 |
+
|
| 139 |
+
/// TensorRef to non-constant data
|
| 140 |
+
using NonConstTensorRef = TensorRefPlanarComplex<
|
| 141 |
+
typename platform::remove_const<Element>::type,
|
| 142 |
+
Layout>;
|
| 143 |
+
|
| 144 |
+
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
| 145 |
+
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
| 146 |
+
/// extensive C++ metaprogramming or support for zero-length arrays.
|
| 147 |
+
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
| 148 |
+
|
| 149 |
+
private:
|
| 150 |
+
|
| 151 |
+
/// Pointer
|
| 152 |
+
Element* ptr_;
|
| 153 |
+
|
| 154 |
+
/// Layout object maps logical coordinates to linear offsets
|
| 155 |
+
Layout layout_;
|
| 156 |
+
|
| 157 |
+
/// Offset to imaginary part
|
| 158 |
+
LongIndex imaginary_stride_;
|
| 159 |
+
|
| 160 |
+
public:
|
| 161 |
+
|
| 162 |
+
//
|
| 163 |
+
// Methods
|
| 164 |
+
//
|
| 165 |
+
|
| 166 |
+
/// Constructs a TensorRef with a pointer and layout object.
|
| 167 |
+
CUTLASS_HOST_DEVICE
|
| 168 |
+
TensorRefPlanarComplex(
|
| 169 |
+
Element *ptr = nullptr, ///< pointer to start of tensor
|
| 170 |
+
Layout const &layout = Layout(), ///< layout object containing stride and mapping function
|
| 171 |
+
LongIndex imaginary_stride = 0
|
| 172 |
+
):
|
| 173 |
+
ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) {
|
| 174 |
+
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
/// Converting constructor from TensorRef to non-constant data.
|
| 178 |
+
CUTLASS_HOST_DEVICE
|
| 179 |
+
TensorRefPlanarComplex(
|
| 180 |
+
NonConstTensorRef const &ref ///< TensorRef to non-const data
|
| 181 |
+
):
|
| 182 |
+
ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { }
|
| 183 |
+
|
| 184 |
+
/// Returns a reference to constant-valued tensor.
|
| 185 |
+
CUTLASS_HOST_DEVICE
|
| 186 |
+
ConstTensorRef const_ref() const {
|
| 187 |
+
return ConstTensorRef(ptr_, layout_, imaginary_stride_);
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
NonConstTensorRef non_const_ref() const {
|
| 192 |
+
return NonConstTensorRef(
|
| 193 |
+
const_cast<typename platform::remove_const<Element>::type *>(ptr_),
|
| 194 |
+
layout_,
|
| 195 |
+
imaginary_stride_);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
/// Updates only the pointer
|
| 199 |
+
CUTLASS_HOST_DEVICE
|
| 200 |
+
void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) {
|
| 201 |
+
ptr_ = ptr;
|
| 202 |
+
imaginary_stride_ = imaginary_stride;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
/// Updates the pointer and layout object
|
| 206 |
+
CUTLASS_HOST_DEVICE
|
| 207 |
+
void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) {
|
| 208 |
+
ptr_ = ptr;
|
| 209 |
+
layout_ = layout;
|
| 210 |
+
imaginary_stride_ = imaginary_stride;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Returns true if the TensorRef is non-null
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
bool good() const {
|
| 216 |
+
return ptr_ != nullptr;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/// Returns the pointer to referenced data
|
| 220 |
+
CUTLASS_HOST_DEVICE
|
| 221 |
+
Element * data() const { return ptr_; }
|
| 222 |
+
|
| 223 |
+
/// Returns the pointer to referenced data
|
| 224 |
+
CUTLASS_HOST_DEVICE
|
| 225 |
+
Element * imaginary_data() const { return ptr_ + imaginary_stride_; }
|
| 226 |
+
|
| 227 |
+
/// Returns a reference to the element at a given linear index
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
Reference data(LongIndex idx) const {
|
| 230 |
+
return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/// Returns the layout object
|
| 234 |
+
CUTLASS_HOST_DEVICE
|
| 235 |
+
Layout & layout() {
|
| 236 |
+
return layout_;
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
/// Returns the layout object
|
| 240 |
+
CUTLASS_HOST_DEVICE
|
| 241 |
+
Layout layout() const {
|
| 242 |
+
return layout_;
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
/// Gets the stride to an imaginary element
|
| 246 |
+
LongIndex imaginary_stride() const {
|
| 247 |
+
return imaginary_stride_;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
/// Gets the stride to an imaginary element
|
| 251 |
+
LongIndex &imaginary_stride() {
|
| 252 |
+
return imaginary_stride_;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
/// Returns the layout object's stride vector
|
| 256 |
+
CUTLASS_HOST_DEVICE
|
| 257 |
+
Stride stride() const {
|
| 258 |
+
return layout_.stride();
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/// Returns the layout object's stride vector
|
| 262 |
+
CUTLASS_HOST_DEVICE
|
| 263 |
+
Stride & stride() {
|
| 264 |
+
return layout_.stride();
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 268 |
+
CUTLASS_HOST_DEVICE
|
| 269 |
+
Index stride(int dim) const {
|
| 270 |
+
return layout_.stride().at(dim);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/// Returns the layout object's stride in a given physical dimension
|
| 274 |
+
CUTLASS_HOST_DEVICE
|
| 275 |
+
Index & stride(int dim) {
|
| 276 |
+
return layout_.stride().at(dim);
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/// Computes the offset of an index from the origin of the tensor
|
| 280 |
+
CUTLASS_HOST_DEVICE
|
| 281 |
+
LongIndex offset(TensorCoord const& coord) const {
|
| 282 |
+
return layout_(coord);
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
/// Returns a reference to the element at a given Coord
|
| 286 |
+
CUTLASS_HOST_DEVICE
|
| 287 |
+
Reference at(TensorCoord const& coord) const {
|
| 288 |
+
return data(offset(coord));
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
/// Returns a reference to the element at a given Coord
|
| 292 |
+
CUTLASS_HOST_DEVICE
|
| 293 |
+
Reference operator[](TensorCoord const& coord) const {
|
| 294 |
+
return data(offset(coord));
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
/// Adds an offset to each pointer
|
| 298 |
+
CUTLASS_HOST_DEVICE
|
| 299 |
+
TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) {
|
| 300 |
+
ptr_ += offset_;
|
| 301 |
+
return *this;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/// Adds an offset to each pointer
|
| 305 |
+
CUTLASS_HOST_DEVICE
|
| 306 |
+
TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) {
|
| 307 |
+
add_pointer_offset(offset(coord));
|
| 308 |
+
return *this;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
/// Returns a TensorRef offset by a given amount
|
| 312 |
+
CUTLASS_HOST_DEVICE
|
| 313 |
+
TensorRefPlanarComplex operator+(TensorCoord const& b) const {
|
| 314 |
+
TensorRefPlanarComplex result(*this);
|
| 315 |
+
result.add_coord_offset(b);
|
| 316 |
+
return result;
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
/// Returns a TensorRef offset by a given amount
|
| 320 |
+
CUTLASS_HOST_DEVICE
|
| 321 |
+
TensorRefPlanarComplex & operator+=(TensorCoord const& b) {
|
| 322 |
+
add_coord_offset(b);
|
| 323 |
+
return *this;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
/// Returns a TensorRef offset by a given amount
|
| 327 |
+
CUTLASS_HOST_DEVICE
|
| 328 |
+
TensorRefPlanarComplex operator-(TensorCoord const& b) const {
|
| 329 |
+
TensorRefPlanarComplex result(*this);
|
| 330 |
+
result.add_pointer_offset(-offset(b));
|
| 331 |
+
return result;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
/// Returns a TensorRef offset by a given amount
|
| 335 |
+
CUTLASS_HOST_DEVICE
|
| 336 |
+
TensorRefPlanarComplex & operator-=(TensorCoord const& b) {
|
| 337 |
+
add_pointer_offset(-offset(b));
|
| 338 |
+
return *this;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
/// TensorRef to real-valued tensor
|
| 342 |
+
CUTLASS_HOST_DEVICE
|
| 343 |
+
cutlass::TensorRef<Element, Layout> ref_real() const {
|
| 344 |
+
return cutlass::TensorRef<Element, Layout>(data(), layout());
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/// TensorRef to real-valued tensor
|
| 348 |
+
CUTLASS_HOST_DEVICE
|
| 349 |
+
cutlass::TensorRef<Element, Layout> ref_imag() const {
|
| 350 |
+
return cutlass::TensorRef<Element, Layout>(imaginary_data(), layout());
|
| 351 |
+
}
|
| 352 |
+
};
|
| 353 |
+
|
| 354 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 355 |
+
|
| 356 |
+
/// Constructs a TensorRef, deducing types from arguments.
|
| 357 |
+
template <
|
| 358 |
+
typename Element,
|
| 359 |
+
typename Layout
|
| 360 |
+
>
|
| 361 |
+
CUTLASS_HOST_DEVICE
|
| 362 |
+
TensorRefPlanarComplex<Element, Layout> make_TensorRefPlanarComplex(
|
| 363 |
+
Element *ptr,
|
| 364 |
+
Layout const &layout,
|
| 365 |
+
int64_t imaginary_stride) {
|
| 366 |
+
|
| 367 |
+
return TensorRefPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 371 |
+
|
| 372 |
+
} // namespace cutlass
|
| 373 |
+
|
| 374 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a structure containing strides and a pointer to tensor data.
|
| 33 |
+
|
| 34 |
+
TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
|
| 35 |
+
it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
|
| 36 |
+
data storage and is therefore lightweight and may be embedded in larger tensor objects or
|
| 37 |
+
memory structures.
|
| 38 |
+
|
| 39 |
+
See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
|
| 40 |
+
linear memory.
|
| 41 |
+
*/
|
| 42 |
+
|
| 43 |
+
#pragma once
|
| 44 |
+
|
| 45 |
+
#if !defined(__CUDACC_RTC__)
|
| 46 |
+
#include <cmath>
|
| 47 |
+
#endif
|
| 48 |
+
|
| 49 |
+
#include "cutlass/cutlass.h"
|
| 50 |
+
#include "cutlass/tensor_ref.h"
|
| 51 |
+
|
| 52 |
+
namespace cutlass {
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
template <
|
| 57 |
+
/// Data type of element stored within tensor
|
| 58 |
+
typename Element_,
|
| 59 |
+
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
| 60 |
+
typename Layout_
|
| 61 |
+
>
|
| 62 |
+
class TensorView : public TensorRef<Element_, Layout_> {
|
| 63 |
+
public:
|
| 64 |
+
|
| 65 |
+
/// Base tensor reference
|
| 66 |
+
using Base = cutlass::TensorRef<Element_, Layout_>;
|
| 67 |
+
|
| 68 |
+
/// Mapping function from logical coordinate to internal n-D array
|
| 69 |
+
using Layout = Layout_;
|
| 70 |
+
|
| 71 |
+
/// TensorRef pointing to constant memory
|
| 72 |
+
using ConstTensorRef = typename Base::ConstTensorRef;
|
| 73 |
+
|
| 74 |
+
/// Underlying TensorRef type
|
| 75 |
+
using TensorRef = Base;
|
| 76 |
+
|
| 77 |
+
/// Data type of individual access
|
| 78 |
+
using Element = Element_;
|
| 79 |
+
|
| 80 |
+
/// Reference type to an element
|
| 81 |
+
using Reference = Element &;
|
| 82 |
+
|
| 83 |
+
/// Logical rank of tensor index space
|
| 84 |
+
static int const kRank = Layout::kRank;
|
| 85 |
+
|
| 86 |
+
/// Index type
|
| 87 |
+
using Index = typename Layout::Index;
|
| 88 |
+
|
| 89 |
+
/// Long index used for pointer offsets
|
| 90 |
+
using LongIndex = typename Layout::LongIndex;
|
| 91 |
+
|
| 92 |
+
/// Coordinate in logical tensor space
|
| 93 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 94 |
+
|
| 95 |
+
/// Coordinate in storage n-D array
|
| 96 |
+
using Stride = typename Layout::Stride;
|
| 97 |
+
|
| 98 |
+
/// TensorView pointing to constant memory
|
| 99 |
+
using ConstTensorView = TensorView<
|
| 100 |
+
typename platform::remove_const<Element>::type const,
|
| 101 |
+
Layout>;
|
| 102 |
+
|
| 103 |
+
/// TensorView pointing to non-constant memory
|
| 104 |
+
using NonConstTensorView = TensorView<
|
| 105 |
+
typename platform::remove_const<Element>::type,
|
| 106 |
+
Layout>;
|
| 107 |
+
|
| 108 |
+
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
| 109 |
+
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
| 110 |
+
/// extensive C++ metaprogramming or support for zero-length arrays.
|
| 111 |
+
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
| 112 |
+
|
| 113 |
+
private:
|
| 114 |
+
|
| 115 |
+
/// View extent
|
| 116 |
+
TensorCoord extent_;
|
| 117 |
+
|
| 118 |
+
public:
|
| 119 |
+
|
| 120 |
+
//
|
| 121 |
+
// Methods
|
| 122 |
+
//
|
| 123 |
+
|
| 124 |
+
/// Constructs a TensorView object
|
| 125 |
+
CUTLASS_HOST_DEVICE
|
| 126 |
+
TensorView() { }
|
| 127 |
+
|
| 128 |
+
/// Constructs a TensorView object
|
| 129 |
+
CUTLASS_HOST_DEVICE
|
| 130 |
+
TensorView(
|
| 131 |
+
Element *ptr, ///< pointer to start of tensor
|
| 132 |
+
Layout const &layout, ///< layout object containing stride and mapping function
|
| 133 |
+
TensorCoord const &extent ///< size of the view in logical coordinates
|
| 134 |
+
):
|
| 135 |
+
Base(ptr, layout), extent_(extent) {
|
| 136 |
+
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Constructs a TensorView object
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
TensorView(
|
| 142 |
+
TensorRef const &ref, ///< pointer and layout object referencing a tensor
|
| 143 |
+
TensorCoord const &extent ///< logical size of tensor
|
| 144 |
+
):
|
| 145 |
+
Base(ref), extent_(extent) {
|
| 146 |
+
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
/// Converting constructor from TensorRef to non-constant data.
|
| 150 |
+
CUTLASS_HOST_DEVICE
|
| 151 |
+
TensorView(
|
| 152 |
+
NonConstTensorView const &view ///< TensorView to non-const data
|
| 153 |
+
):
|
| 154 |
+
Base(view), extent_(view.extent_) { }
|
| 155 |
+
|
| 156 |
+
/// Updates the pointer and layout object
|
| 157 |
+
CUTLASS_HOST_DEVICE
|
| 158 |
+
void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) {
|
| 159 |
+
Base::reset(ptr, layout);
|
| 160 |
+
this->resize(extent);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Updates the pointer
|
| 164 |
+
CUTLASS_HOST_DEVICE
|
| 165 |
+
void reset(Element* ptr) {
|
| 166 |
+
Base::reset(ptr);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/// Changes the size of the view without affecting pointer or layout
|
| 170 |
+
CUTLASS_HOST_DEVICE
|
| 171 |
+
void resize(TensorCoord const &extent) {
|
| 172 |
+
this->extent_ = extent;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
/// Returns the extent of the view (the size along each logical dimension).
|
| 176 |
+
CUTLASS_HOST_DEVICE
|
| 177 |
+
TensorCoord const& extent() const { return extent_; }
|
| 178 |
+
|
| 179 |
+
/// Returns the extent along a particular logical dimension.
|
| 180 |
+
CUTLASS_HOST_DEVICE
|
| 181 |
+
Index extent(int dim) const { return extent_.at(dim); }
|
| 182 |
+
|
| 183 |
+
/// Returns the number of logical elements
|
| 184 |
+
CUTLASS_HOST_DEVICE
|
| 185 |
+
LongIndex size() const {
|
| 186 |
+
return extent_.product();
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/// Determines whether a location is within a tensor
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
bool contains(TensorCoord const& coord) const {
|
| 192 |
+
CUTLASS_PRAGMA_UNROLL
|
| 193 |
+
for (int dim = 0; dim < kRank; ++dim) {
|
| 194 |
+
if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
|
| 195 |
+
return false;
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
return true;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 202 |
+
CUTLASS_HOST_DEVICE
|
| 203 |
+
TensorRef ref() const {
|
| 204 |
+
return TensorRef(this->data(), this->layout());
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 208 |
+
CUTLASS_HOST_DEVICE
|
| 209 |
+
ConstTensorRef const_ref() const {
|
| 210 |
+
return ConstTensorRef(this->data(), this->layout());
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Returns a TensorView to const data
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
ConstTensorView const_view() const {
|
| 216 |
+
return ConstTensorView(const_ref(), extent_);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/// Returns a Tensor_view given location and size quantities
|
| 220 |
+
CUTLASS_HOST_DEVICE
|
| 221 |
+
TensorView subview(
|
| 222 |
+
TensorCoord extent, ///< extent of the resulting view
|
| 223 |
+
TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view
|
| 224 |
+
) const {
|
| 225 |
+
|
| 226 |
+
TensorView result(this->ref(), extent.clamp(extent_ - location));
|
| 227 |
+
result.add_coord_offset(location);
|
| 228 |
+
return result;
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
/// Returns the number of scalar elements needed to store tensor.
|
| 232 |
+
CUTLASS_HOST_DEVICE
|
| 233 |
+
size_t capacity() const {
|
| 234 |
+
return Base::layout().capacity(extent_);
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/// Returns a TensorView offset by a given amount
|
| 238 |
+
CUTLASS_HOST_DEVICE
|
| 239 |
+
TensorView operator+(
|
| 240 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 241 |
+
) const {
|
| 242 |
+
|
| 243 |
+
TensorView result(*this);
|
| 244 |
+
result.add_pointer_offset(this->offset(b));
|
| 245 |
+
return result;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
/// Returns a TensorRef offset by a given amount
|
| 249 |
+
CUTLASS_HOST_DEVICE
|
| 250 |
+
TensorView& operator+=(
|
| 251 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 252 |
+
) {
|
| 253 |
+
|
| 254 |
+
this->add_pointer_offset(this->offset(b));
|
| 255 |
+
return *this;
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// Returns a TensorRef offset by a given amount
|
| 259 |
+
CUTLASS_HOST_DEVICE
|
| 260 |
+
TensorView operator-(
|
| 261 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 262 |
+
) const {
|
| 263 |
+
|
| 264 |
+
TensorRef result(*this);
|
| 265 |
+
result.add_pointer_offset(-this->offset(b));
|
| 266 |
+
return result;
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
/// Returns a TensorRef offset by a given amount
|
| 270 |
+
CUTLASS_HOST_DEVICE
|
| 271 |
+
TensorView& operator-=(
|
| 272 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 273 |
+
) {
|
| 274 |
+
|
| 275 |
+
this->add_pointer_offset(-this->offset(b));
|
| 276 |
+
return *this;
|
| 277 |
+
}
|
| 278 |
+
};
|
| 279 |
+
|
| 280 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 281 |
+
|
| 282 |
+
/// Constructs a TensorRef, deducing types from arguments.
|
| 283 |
+
template <
|
| 284 |
+
typename Element,
|
| 285 |
+
typename Layout
|
| 286 |
+
>
|
| 287 |
+
CUTLASS_HOST_DEVICE TensorView<Element, Layout> make_TensorView(
|
| 288 |
+
Element *ptr,
|
| 289 |
+
Layout const &layout,
|
| 290 |
+
typename Layout::TensorCoord const &extent) {
|
| 291 |
+
|
| 292 |
+
return TensorView<Element, Layout>(ptr, layout, extent);
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 296 |
+
|
| 297 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a structure containing strides and a pointer to tensor data.
|
| 33 |
+
|
| 34 |
+
TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus,
|
| 35 |
+
it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from
|
| 36 |
+
data storage and is therefore lightweight and may be embedded in larger tensor objects or
|
| 37 |
+
memory structures.
|
| 38 |
+
|
| 39 |
+
See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to
|
| 40 |
+
linear memory.
|
| 41 |
+
*/
|
| 42 |
+
|
| 43 |
+
#pragma once
|
| 44 |
+
|
| 45 |
+
#if !defined(__CUDACC_RTC__)
|
| 46 |
+
#include <cmath>
|
| 47 |
+
#endif
|
| 48 |
+
|
| 49 |
+
#include "cutlass/cutlass.h"
|
| 50 |
+
#include "cutlass/tensor_ref_planar_complex.h"
|
| 51 |
+
#include "cutlass/tensor_view.h" // cutlass::TensorView
|
| 52 |
+
|
| 53 |
+
namespace cutlass {
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
template <
|
| 58 |
+
/// Data type of element stored within tensor
|
| 59 |
+
typename Element_,
|
| 60 |
+
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
| 61 |
+
typename Layout_
|
| 62 |
+
>
|
| 63 |
+
class TensorViewPlanarComplex : public TensorRefPlanarComplex<Element_, Layout_> {
|
| 64 |
+
public:
|
| 65 |
+
|
| 66 |
+
/// Base tensor reference
|
| 67 |
+
using Base = cutlass::TensorRefPlanarComplex<Element_, Layout_>;
|
| 68 |
+
|
| 69 |
+
/// Mapping function from logical coordinate to internal n-D array
|
| 70 |
+
using Layout = Layout_;
|
| 71 |
+
|
| 72 |
+
/// TensorRef pointing to constant memory
|
| 73 |
+
using ConstTensorRef = typename Base::ConstTensorRef;
|
| 74 |
+
|
| 75 |
+
/// Underlying TensorRef type
|
| 76 |
+
using TensorRef = Base;
|
| 77 |
+
|
| 78 |
+
/// Data type of individual access
|
| 79 |
+
using Element = Element_;
|
| 80 |
+
|
| 81 |
+
/// Reference type to an element
|
| 82 |
+
using Reference = Element &;
|
| 83 |
+
|
| 84 |
+
/// Logical rank of tensor index space
|
| 85 |
+
static int const kRank = Layout::kRank;
|
| 86 |
+
|
| 87 |
+
/// Index type
|
| 88 |
+
using Index = typename Layout::Index;
|
| 89 |
+
|
| 90 |
+
/// Long index used for pointer offsets
|
| 91 |
+
using LongIndex = typename Layout::LongIndex;
|
| 92 |
+
|
| 93 |
+
/// Coordinate in logical tensor space
|
| 94 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 95 |
+
|
| 96 |
+
/// Coordinate in storage n-D array
|
| 97 |
+
using Stride = typename Layout::Stride;
|
| 98 |
+
|
| 99 |
+
/// TensorView pointing to constant memory
|
| 100 |
+
using ConstTensorView = TensorViewPlanarComplex<
|
| 101 |
+
typename platform::remove_const<Element>::type const,
|
| 102 |
+
Layout>;
|
| 103 |
+
|
| 104 |
+
/// TensorView pointing to non-constant memory
|
| 105 |
+
using NonConstTensorView = TensorViewPlanarComplex<
|
| 106 |
+
typename platform::remove_const<Element>::type,
|
| 107 |
+
Layout>;
|
| 108 |
+
|
| 109 |
+
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
| 110 |
+
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
| 111 |
+
/// extensive C++ metaprogramming or support for zero-length arrays.
|
| 112 |
+
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
| 113 |
+
|
| 114 |
+
private:
|
| 115 |
+
|
| 116 |
+
/// View extent
|
| 117 |
+
TensorCoord extent_;
|
| 118 |
+
|
| 119 |
+
public:
|
| 120 |
+
|
| 121 |
+
//
|
| 122 |
+
// Methods
|
| 123 |
+
//
|
| 124 |
+
|
| 125 |
+
/// Constructs a TensorView object
|
| 126 |
+
CUTLASS_HOST_DEVICE
|
| 127 |
+
TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) {
|
| 128 |
+
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
/// Constructs a TensorView object
|
| 132 |
+
CUTLASS_HOST_DEVICE
|
| 133 |
+
TensorViewPlanarComplex(
|
| 134 |
+
Element *ptr, ///< pointer to start of tensor
|
| 135 |
+
Layout const &layout, ///< layout object containing stride and mapping function
|
| 136 |
+
LongIndex imaginary_stride, ///< stride between real and imaginary part
|
| 137 |
+
TensorCoord const &extent ///< size of the view in logical coordinates
|
| 138 |
+
):
|
| 139 |
+
Base(ptr, layout, imaginary_stride), extent_(extent) {
|
| 140 |
+
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/// Constructs a TensorView object
|
| 144 |
+
CUTLASS_HOST_DEVICE
|
| 145 |
+
TensorViewPlanarComplex(
|
| 146 |
+
TensorRef const &ref, ///< pointer and layout object referencing a tensor
|
| 147 |
+
TensorCoord const &extent ///< logical size of tensor
|
| 148 |
+
):
|
| 149 |
+
Base(ref), extent_(extent) {
|
| 150 |
+
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
/// Converting constructor from TensorRef to non-constant data.
|
| 154 |
+
CUTLASS_HOST_DEVICE
|
| 155 |
+
TensorViewPlanarComplex(
|
| 156 |
+
NonConstTensorView const &view ///< TensorView to non-const data
|
| 157 |
+
):
|
| 158 |
+
Base(view), extent_(view.extent_) { }
|
| 159 |
+
|
| 160 |
+
/// Updates the pointer and layout object
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) {
|
| 163 |
+
Base::reset(ptr, layout, imaginary_stride);
|
| 164 |
+
this->resize(extent_);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
/// Changes the size of the view without affecting pointer or layout
|
| 168 |
+
CUTLASS_HOST_DEVICE
|
| 169 |
+
void resize(TensorCoord extent) {
|
| 170 |
+
this->extent_ = extent;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Returns the extent of the view (the size along each logical dimension).
|
| 174 |
+
CUTLASS_HOST_DEVICE
|
| 175 |
+
TensorCoord const& extent() const { return extent_; }
|
| 176 |
+
|
| 177 |
+
/// Returns the extent along a particular logical dimension.
|
| 178 |
+
CUTLASS_HOST_DEVICE
|
| 179 |
+
Index extent(int dim) const { return extent_.at(dim); }
|
| 180 |
+
|
| 181 |
+
/// Determines whether a location is within a tensor
|
| 182 |
+
CUTLASS_HOST_DEVICE
|
| 183 |
+
bool contains(TensorCoord const& coord) const {
|
| 184 |
+
CUTLASS_PRAGMA_UNROLL
|
| 185 |
+
for (int dim = 0; dim < kRank; ++dim) {
|
| 186 |
+
if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) {
|
| 187 |
+
return false;
|
| 188 |
+
}
|
| 189 |
+
}
|
| 190 |
+
return true;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 194 |
+
CUTLASS_HOST_DEVICE
|
| 195 |
+
Base ref() const {
|
| 196 |
+
return Base(this->data(), this->layout(), this->imaginary_stride());
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 200 |
+
CUTLASS_HOST_DEVICE
|
| 201 |
+
ConstTensorRef const_ref() const {
|
| 202 |
+
return ConstTensorRef(this->data(), this->layout());
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
/// Returns a TensorView to const data
|
| 206 |
+
CUTLASS_HOST_DEVICE
|
| 207 |
+
ConstTensorView const_view() const {
|
| 208 |
+
return ConstTensorView(const_ref(), extent_);
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/// Returns a Tensor_view given location and size quantities
|
| 212 |
+
CUTLASS_HOST_DEVICE
|
| 213 |
+
TensorViewPlanarComplex subview(
|
| 214 |
+
TensorCoord extent, ///< extent of the resulting view
|
| 215 |
+
TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view
|
| 216 |
+
) const {
|
| 217 |
+
|
| 218 |
+
TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location));
|
| 219 |
+
result.add_coord_offset(location);
|
| 220 |
+
return result;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
/// Returns the number of scalar elements needed to store tensor.
|
| 224 |
+
CUTLASS_HOST_DEVICE
|
| 225 |
+
size_t capacity() const {
|
| 226 |
+
return Base::layout().capacity(extent_);
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
/// Returns a TensorView offset by a given amount
|
| 230 |
+
CUTLASS_HOST_DEVICE
|
| 231 |
+
TensorViewPlanarComplex operator+(
|
| 232 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 233 |
+
) const {
|
| 234 |
+
|
| 235 |
+
TensorViewPlanarComplex result(*this);
|
| 236 |
+
result.add_pointer_offset(this->offset(b));
|
| 237 |
+
return result;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Returns a TensorRef offset by a given amount
|
| 241 |
+
CUTLASS_HOST_DEVICE
|
| 242 |
+
TensorViewPlanarComplex& operator+=(
|
| 243 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 244 |
+
) {
|
| 245 |
+
|
| 246 |
+
this->add_pointer_offset(this->offset(b));
|
| 247 |
+
return *this;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
/// Returns a TensorRef offset by a given amount
|
| 251 |
+
CUTLASS_HOST_DEVICE
|
| 252 |
+
TensorViewPlanarComplex operator-(
|
| 253 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 254 |
+
) const {
|
| 255 |
+
|
| 256 |
+
TensorRef result(*this);
|
| 257 |
+
result.add_pointer_offset(-this->offset(b));
|
| 258 |
+
return result;
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
/// Returns a TensorRef offset by a given amount
|
| 262 |
+
CUTLASS_HOST_DEVICE
|
| 263 |
+
TensorViewPlanarComplex& operator-=(
|
| 264 |
+
TensorCoord const& b ///< offset in the logical coordinate space of the tensor
|
| 265 |
+
) {
|
| 266 |
+
|
| 267 |
+
this->add_pointer_offset(-this->offset(b));
|
| 268 |
+
return *this;
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
/// TensorRef to real-valued tensor
|
| 272 |
+
CUTLASS_HOST_DEVICE
|
| 273 |
+
cutlass::TensorView<Element, Layout> view_real() const {
|
| 274 |
+
return cutlass::TensorView<Element, Layout>(this->data(), this->layout(), extent_);
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
/// TensorRef to real-valued tensor
|
| 278 |
+
CUTLASS_HOST_DEVICE
|
| 279 |
+
cutlass::TensorView<Element, Layout> view_imag() const {
|
| 280 |
+
return cutlass::TensorView<Element, Layout>(this->imaginary_data(), this->layout(), extent_);
|
| 281 |
+
}
|
| 282 |
+
};
|
| 283 |
+
|
| 284 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 285 |
+
|
| 286 |
+
/// Constructs a TensorRef, deducing types from arguments.
|
| 287 |
+
template <
|
| 288 |
+
typename Element,
|
| 289 |
+
typename Layout
|
| 290 |
+
>
|
| 291 |
+
CUTLASS_HOST_DEVICE TensorViewPlanarComplex<Element, Layout> make_TensorViewPlanarComplex(
|
| 292 |
+
Element *ptr,
|
| 293 |
+
Layout const &layout,
|
| 294 |
+
typename Layout::LongIndex imaginary_stride,
|
| 295 |
+
typename Layout::TensorCoord const &extent) {
|
| 296 |
+
|
| 297 |
+
return TensorViewPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride, extent);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 301 |
+
|
| 302 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*!
|
| 32 |
+
\file
|
| 33 |
+
\brief Defines a proxy class for storing Tensor Float 32 data type.
|
| 34 |
+
*/
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#if defined(__CUDACC_RTC__)
|
| 38 |
+
#include "cutlass/floating_point_nvrtc.h"
|
| 39 |
+
#else
|
| 40 |
+
#include <cmath>
|
| 41 |
+
#include <limits>
|
| 42 |
+
#include <cstdint>
|
| 43 |
+
#include <cstring> // std::memcpy
|
| 44 |
+
#endif
|
| 45 |
+
|
| 46 |
+
#include "cutlass/cutlass.h"
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
|
| 50 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
|
| 52 |
+
/// Tensor Float 32 data type
|
| 53 |
+
struct alignas(4) tfloat32_t {
|
| 54 |
+
|
| 55 |
+
//
|
| 56 |
+
// Data members
|
| 57 |
+
//
|
| 58 |
+
|
| 59 |
+
/// Storage type
|
| 60 |
+
uint32_t storage;
|
| 61 |
+
|
| 62 |
+
//
|
| 63 |
+
// Methods
|
| 64 |
+
//
|
| 65 |
+
private:
|
| 66 |
+
CUTLASS_HOST_DEVICE
|
| 67 |
+
static uint32_t float_to_storage(float s) {
|
| 68 |
+
#if defined(__CUDA_ARCH__)
|
| 69 |
+
uint32_t result = reinterpret_cast<uint32_t const &>(s);
|
| 70 |
+
#else
|
| 71 |
+
uint32_t result;
|
| 72 |
+
std::memcpy(&result, &s, sizeof(float));
|
| 73 |
+
#endif
|
| 74 |
+
return result;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
public:
|
| 78 |
+
/// Constructs from an unsigned int
|
| 79 |
+
CUTLASS_HOST_DEVICE
|
| 80 |
+
static tfloat32_t bitcast(uint32_t x) {
|
| 81 |
+
tfloat32_t h;
|
| 82 |
+
h.storage = x;
|
| 83 |
+
return h;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
/// Emulated rounding is fast in device code
|
| 87 |
+
CUTLASS_HOST_DEVICE
|
| 88 |
+
static tfloat32_t round_half_ulp_truncate(float const &s) {
|
| 89 |
+
uint32_t x = float_to_storage(s);
|
| 90 |
+
|
| 91 |
+
#if defined(__CUDA_ARCH__)
|
| 92 |
+
if (::isfinite(s)) {
|
| 93 |
+
x += 0x1000u;
|
| 94 |
+
}
|
| 95 |
+
#else
|
| 96 |
+
if (std::isfinite(s)) {
|
| 97 |
+
x += 0x1000u;
|
| 98 |
+
}
|
| 99 |
+
#endif
|
| 100 |
+
|
| 101 |
+
return tfloat32_t::bitcast(x);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
tfloat32_t() = default;
|
| 105 |
+
|
| 106 |
+
/// Floating-point conversion - round toward nearest even
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { }
|
| 109 |
+
|
| 110 |
+
// Conversion from double (this rounds twice)
|
| 111 |
+
CUTLASS_HOST_DEVICE
|
| 112 |
+
explicit tfloat32_t(double x): tfloat32_t(float(x)) { }
|
| 113 |
+
|
| 114 |
+
/// Integer conversion - round toward zero
|
| 115 |
+
CUTLASS_HOST_DEVICE
|
| 116 |
+
explicit tfloat32_t(int x) {
|
| 117 |
+
float flt = static_cast<float>(x);
|
| 118 |
+
#if defined(__CUDA_ARCH__)
|
| 119 |
+
storage = reinterpret_cast<uint32_t const &>(flt);
|
| 120 |
+
#else
|
| 121 |
+
std::memcpy(&storage, &flt, sizeof(storage));
|
| 122 |
+
#endif
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Conversion to float
|
| 126 |
+
CUTLASS_HOST_DEVICE
|
| 127 |
+
operator float() const {
|
| 128 |
+
|
| 129 |
+
// Conversions to IEEE single-precision requires clearing dont-care bits
|
| 130 |
+
// of the mantissa.
|
| 131 |
+
unsigned bits = (storage & ~0x1fffu);
|
| 132 |
+
|
| 133 |
+
#if defined(__CUDA_ARCH__)
|
| 134 |
+
return reinterpret_cast<float const &>(bits);
|
| 135 |
+
#else
|
| 136 |
+
float flt;
|
| 137 |
+
std::memcpy(&flt, &bits, sizeof(flt));
|
| 138 |
+
return flt;
|
| 139 |
+
#endif
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
/// Converts to double
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
explicit operator double() const {
|
| 145 |
+
return double(float(*this));
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// Converts to int
|
| 149 |
+
CUTLASS_HOST_DEVICE
|
| 150 |
+
explicit operator int() const {
|
| 151 |
+
return int(float(*this));
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// Casts to bool
|
| 155 |
+
CUTLASS_HOST_DEVICE
|
| 156 |
+
explicit operator bool() const {
|
| 157 |
+
return (float(*this) != 0.0f);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Obtains raw bits
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
uint32_t raw() const {
|
| 163 |
+
return storage;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
/// Returns the sign bit
|
| 167 |
+
CUTLASS_HOST_DEVICE
|
| 168 |
+
bool signbit() const {
|
| 169 |
+
return ((raw() & 0x80000000) != 0);
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/// Returns the biased exponent
|
| 173 |
+
CUTLASS_HOST_DEVICE
|
| 174 |
+
int exponent_biased() const {
|
| 175 |
+
return int((raw() >> 23) & 0x0ff);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/// Returns the unbiased exponent
|
| 179 |
+
CUTLASS_HOST_DEVICE
|
| 180 |
+
int exponent() const {
|
| 181 |
+
return exponent_biased() - 127;
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
/// Returns the mantissa
|
| 185 |
+
CUTLASS_HOST_DEVICE
|
| 186 |
+
int mantissa() const {
|
| 187 |
+
return int(raw() & 0x7fffff);
|
| 188 |
+
}
|
| 189 |
+
};
|
| 190 |
+
|
| 191 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 192 |
+
|
| 193 |
+
CUTLASS_HOST_DEVICE
|
| 194 |
+
bool signbit(cutlass::tfloat32_t const& h) {
|
| 195 |
+
return h.signbit();
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
CUTLASS_HOST_DEVICE
|
| 199 |
+
cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) {
|
| 200 |
+
return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
CUTLASS_HOST_DEVICE
|
| 204 |
+
bool isnan(cutlass::tfloat32_t const& h) {
|
| 205 |
+
return (h.exponent_biased() == 0x0ff) && h.mantissa();
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
CUTLASS_HOST_DEVICE
|
| 209 |
+
bool isfinite(cutlass::tfloat32_t const& h) {
|
| 210 |
+
return (h.exponent_biased() != 0x0ff);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
CUTLASS_HOST_DEVICE
|
| 214 |
+
cutlass::tfloat32_t nan_tf32(const char*) {
|
| 215 |
+
// NVIDIA canonical NaN
|
| 216 |
+
return cutlass::tfloat32_t::bitcast(0x7fffffff);
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
CUTLASS_HOST_DEVICE
|
| 220 |
+
bool isinf(cutlass::tfloat32_t const& h) {
|
| 221 |
+
return (h.exponent_biased() == 0x0ff) && !h.mantissa();
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
CUTLASS_HOST_DEVICE
|
| 225 |
+
bool isnormal(cutlass::tfloat32_t const& h) {
|
| 226 |
+
return h.exponent_biased() && h.exponent_biased() != 0x0ff;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
CUTLASS_HOST_DEVICE
|
| 230 |
+
int fpclassify(cutlass::tfloat32_t const& h) {
|
| 231 |
+
int exp = h.exponent_biased();
|
| 232 |
+
int mantissa = h.mantissa();
|
| 233 |
+
if (exp == 0x0ff) {
|
| 234 |
+
if (mantissa) {
|
| 235 |
+
return FP_NAN;
|
| 236 |
+
}
|
| 237 |
+
else {
|
| 238 |
+
return FP_INFINITE;
|
| 239 |
+
}
|
| 240 |
+
}
|
| 241 |
+
else if (!exp) {
|
| 242 |
+
if (mantissa) {
|
| 243 |
+
return FP_SUBNORMAL;
|
| 244 |
+
}
|
| 245 |
+
else {
|
| 246 |
+
return FP_ZERO;
|
| 247 |
+
}
|
| 248 |
+
}
|
| 249 |
+
return FP_NORMAL;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
CUTLASS_HOST_DEVICE
|
| 253 |
+
cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) {
|
| 254 |
+
#if defined(__CUDACC_RTC__)
|
| 255 |
+
return cutlass::tfloat32_t(sqrtf(float(h)));
|
| 256 |
+
#else
|
| 257 |
+
return cutlass::tfloat32_t(std::sqrt(float(h)));
|
| 258 |
+
#endif
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
CUTLASS_HOST_DEVICE
|
| 262 |
+
tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) {
|
| 263 |
+
|
| 264 |
+
uint32_t a_mag = (a.raw() & 0x7fffffff);
|
| 265 |
+
uint32_t b_sign = (b.raw() & 0x80000000);
|
| 266 |
+
uint32_t result = (a_mag | b_sign);
|
| 267 |
+
|
| 268 |
+
return tfloat32_t::bitcast(result);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 272 |
+
|
| 273 |
+
} // namespace cutlass
|
| 274 |
+
|
| 275 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 276 |
+
//
|
| 277 |
+
// Standard Library operations and definitions
|
| 278 |
+
//
|
| 279 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 280 |
+
|
| 281 |
+
namespace std {
|
| 282 |
+
|
| 283 |
+
#if !defined(__CUDACC_RTC__)
|
| 284 |
+
/// Numeric limits
|
| 285 |
+
template <>
|
| 286 |
+
struct numeric_limits<cutlass::tfloat32_t> {
|
| 287 |
+
static bool const is_specialized = true;
|
| 288 |
+
static bool const is_signed = true;
|
| 289 |
+
static bool const is_integer = false;
|
| 290 |
+
static bool const is_exact = false;
|
| 291 |
+
static bool const has_infinity = true;
|
| 292 |
+
static bool const has_quiet_NaN = true;
|
| 293 |
+
static bool const has_signaling_NaN = false;
|
| 294 |
+
static std::float_denorm_style const has_denorm = std::denorm_present;
|
| 295 |
+
static bool const has_denorm_loss = true;
|
| 296 |
+
static std::float_round_style const round_style = std::round_to_nearest;
|
| 297 |
+
static bool const is_iec559 = false;
|
| 298 |
+
static bool const is_bounded = true;
|
| 299 |
+
static bool const is_modulo = false;
|
| 300 |
+
static int const digits = 19;
|
| 301 |
+
|
| 302 |
+
/// Least positive value
|
| 303 |
+
static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); }
|
| 304 |
+
|
| 305 |
+
/// Minimum finite value
|
| 306 |
+
static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); }
|
| 307 |
+
|
| 308 |
+
/// Maximum finite value
|
| 309 |
+
static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); }
|
| 310 |
+
|
| 311 |
+
/// Returns smallest finite value
|
| 312 |
+
static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); }
|
| 313 |
+
|
| 314 |
+
/// Returns smallest finite value
|
| 315 |
+
static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); }
|
| 316 |
+
|
| 317 |
+
/// Returns smallest finite value
|
| 318 |
+
static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); }
|
| 319 |
+
|
| 320 |
+
/// Returns smallest finite value
|
| 321 |
+
static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }
|
| 322 |
+
|
| 323 |
+
/// Returns smallest finite value
|
| 324 |
+
static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }
|
| 325 |
+
|
| 326 |
+
/// Returns smallest finite value
|
| 327 |
+
static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); }
|
| 328 |
+
};
|
| 329 |
+
#endif
|
| 330 |
+
|
| 331 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 332 |
+
|
| 333 |
+
} // namespace std
|
| 334 |
+
|
| 335 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 336 |
+
//
|
| 337 |
+
// Arithmetic operators
|
| 338 |
+
//
|
| 339 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 340 |
+
|
| 341 |
+
namespace cutlass {
|
| 342 |
+
|
| 343 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 344 |
+
|
| 345 |
+
CUTLASS_HOST_DEVICE
|
| 346 |
+
bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 347 |
+
return float(lhs) == float(rhs);
|
| 348 |
+
}
|
| 349 |
+
|
| 350 |
+
CUTLASS_HOST_DEVICE
|
| 351 |
+
bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 352 |
+
return float(lhs) != float(rhs);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
CUTLASS_HOST_DEVICE
|
| 356 |
+
bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 357 |
+
return float(lhs) < float(rhs);
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
CUTLASS_HOST_DEVICE
|
| 361 |
+
bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 362 |
+
return float(lhs) <= float(rhs);
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
CUTLASS_HOST_DEVICE
|
| 366 |
+
bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 367 |
+
return float(lhs) > float(rhs);
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
CUTLASS_HOST_DEVICE
|
| 371 |
+
bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 372 |
+
return float(lhs) >= float(rhs);
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
CUTLASS_HOST_DEVICE
|
| 376 |
+
tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 377 |
+
return tfloat32_t(float(lhs) + float(rhs));
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
CUTLASS_HOST_DEVICE
|
| 382 |
+
tfloat32_t operator-(tfloat32_t const& lhs) {
|
| 383 |
+
return tfloat32_t::bitcast(0x80000000 ^ lhs.raw());
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
CUTLASS_HOST_DEVICE
|
| 387 |
+
tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 388 |
+
return tfloat32_t(float(lhs) - float(rhs));
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
CUTLASS_HOST_DEVICE
|
| 392 |
+
tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 393 |
+
return tfloat32_t(float(lhs) * float(rhs));
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
CUTLASS_HOST_DEVICE
|
| 397 |
+
tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) {
|
| 398 |
+
return tfloat32_t(float(lhs) / float(rhs));
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
CUTLASS_HOST_DEVICE
|
| 402 |
+
tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) {
|
| 403 |
+
lhs = tfloat32_t(float(lhs) + float(rhs));
|
| 404 |
+
return lhs;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
CUTLASS_HOST_DEVICE
|
| 408 |
+
tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) {
|
| 409 |
+
lhs = tfloat32_t(float(lhs) - float(rhs));
|
| 410 |
+
return lhs;
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
CUTLASS_HOST_DEVICE
|
| 414 |
+
tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) {
|
| 415 |
+
lhs = tfloat32_t(float(lhs) * float(rhs));
|
| 416 |
+
return lhs;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
CUTLASS_HOST_DEVICE
|
| 420 |
+
tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) {
|
| 421 |
+
lhs = tfloat32_t(float(lhs) / float(rhs));
|
| 422 |
+
return lhs;
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
CUTLASS_HOST_DEVICE
|
| 426 |
+
tfloat32_t& operator++(tfloat32_t & lhs) {
|
| 427 |
+
float tmp(lhs);
|
| 428 |
+
++tmp;
|
| 429 |
+
lhs = tfloat32_t(tmp);
|
| 430 |
+
return lhs;
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
CUTLASS_HOST_DEVICE
|
| 434 |
+
tfloat32_t& operator--(tfloat32_t & lhs) {
|
| 435 |
+
float tmp(lhs);
|
| 436 |
+
--tmp;
|
| 437 |
+
lhs = tfloat32_t(tmp);
|
| 438 |
+
return lhs;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
CUTLASS_HOST_DEVICE
|
| 442 |
+
tfloat32_t operator++(tfloat32_t & lhs, int) {
|
| 443 |
+
tfloat32_t ret(lhs);
|
| 444 |
+
float tmp(lhs);
|
| 445 |
+
tmp++;
|
| 446 |
+
lhs = tfloat32_t(tmp);
|
| 447 |
+
return ret;
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
CUTLASS_HOST_DEVICE
|
| 451 |
+
tfloat32_t operator--(tfloat32_t & lhs, int) {
|
| 452 |
+
tfloat32_t ret(lhs);
|
| 453 |
+
float tmp(lhs);
|
| 454 |
+
tmp--;
|
| 455 |
+
lhs = tfloat32_t(tmp);
|
| 456 |
+
return ret;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 460 |
+
|
| 461 |
+
} // namespace cutlass
|
| 462 |
+
|
| 463 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 464 |
+
|
| 465 |
+
//
|
| 466 |
+
// User-defined literals
|
| 467 |
+
//
|
| 468 |
+
|
| 469 |
+
CUTLASS_HOST_DEVICE
|
| 470 |
+
cutlass::tfloat32_t operator "" _tf32(long double x) {
|
| 471 |
+
return cutlass::tfloat32_t(float(x));
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
CUTLASS_HOST_DEVICE
|
| 475 |
+
cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) {
|
| 476 |
+
return cutlass::tfloat32_t(int(x));
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Defines a matrix object intended for storing data in registers and operations within
|
| 33 |
+
a CUDA thread.
|
| 34 |
+
*/
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/matrix_coord.h"
|
| 40 |
+
|
| 41 |
+
namespace cutlass {
|
| 42 |
+
namespace thread {
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
/// Per-thread matrix object storing a packed matrix
|
| 47 |
+
template <
|
| 48 |
+
typename Element,
|
| 49 |
+
int Rows,
|
| 50 |
+
int Columns,
|
| 51 |
+
typename Layout = layout::RowMajor
|
| 52 |
+
>
|
| 53 |
+
class Matrix : public Array<Element, Rows * Columns> {
|
| 54 |
+
public:
|
| 55 |
+
|
| 56 |
+
// Verify layout refers to a rank=2 matrix.
|
| 57 |
+
static_assert(
|
| 58 |
+
Layout::kRank == 2,
|
| 59 |
+
"Layout type must refer to a rank=2 matrix");
|
| 60 |
+
|
| 61 |
+
/// Base type
|
| 62 |
+
using Base = Array<Element, Rows * Columns>;
|
| 63 |
+
|
| 64 |
+
/// Element type
|
| 65 |
+
using Element = Element_;
|
| 66 |
+
|
| 67 |
+
/// Number of rows
|
| 68 |
+
static int const kRows = Rows;
|
| 69 |
+
|
| 70 |
+
/// Number of columns
|
| 71 |
+
static int const kColumns = Columns;
|
| 72 |
+
|
| 73 |
+
/// Layout within the array
|
| 74 |
+
using Layout = Layout_;
|
| 75 |
+
|
| 76 |
+
/// Reference type to an element
|
| 77 |
+
using Reference = Element &;
|
| 78 |
+
|
| 79 |
+
/// Logical rank of tensor index space
|
| 80 |
+
static int const kRank = 2;
|
| 81 |
+
|
| 82 |
+
/// Index type
|
| 83 |
+
using Index = typename Layout::Index;
|
| 84 |
+
|
| 85 |
+
/// Long index used for pointer offsets
|
| 86 |
+
using LongIndex = typename Layout::LongIndex;
|
| 87 |
+
|
| 88 |
+
/// Coordinate in logical tensor space
|
| 89 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 90 |
+
|
| 91 |
+
/// Stride type
|
| 92 |
+
using Stride = typename Layout::Stride;
|
| 93 |
+
|
| 94 |
+
/// TensorRef to matrix object
|
| 95 |
+
using TensorRef = TensorRef<Element, kRank, Layout>;
|
| 96 |
+
|
| 97 |
+
/// TensorRef to constant matrix object
|
| 98 |
+
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
| 99 |
+
|
| 100 |
+
/// TensorRef to matrix object
|
| 101 |
+
using TensorView = TensorView<Element, kRank, Layout>;
|
| 102 |
+
|
| 103 |
+
/// TensorRef to constant matrix object
|
| 104 |
+
using ConstTensorView = typename TensorView::ConstTensorView;
|
| 105 |
+
|
| 106 |
+
/// Diagonal vector
|
| 107 |
+
using Diagonal = Vector<Element, __NV_STD_MIN(kRows, kColumns)>;
|
| 108 |
+
|
| 109 |
+
private:
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
public:
|
| 113 |
+
|
| 114 |
+
//
|
| 115 |
+
// Methods
|
| 116 |
+
//
|
| 117 |
+
|
| 118 |
+
/// Returns the size of the object
|
| 119 |
+
CUTLASS_HOST_DEVICE
|
| 120 |
+
static MatrixCoord extent() {
|
| 121 |
+
return make_Coord(kRows, kColumns);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Returns the layout object
|
| 125 |
+
CUTLASS_HOST_DEVICE
|
| 126 |
+
static Layout layout() {
|
| 127 |
+
return Layout::packed(extent());
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Ctor
|
| 131 |
+
CUTLASS_HOST_DEVICE
|
| 132 |
+
Matrix() { }
|
| 133 |
+
|
| 134 |
+
/// Ctor
|
| 135 |
+
CUTLASS_HOST_DEVICE
|
| 136 |
+
Matrix(Diagonal const &diag) {
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
TensorRef ref() {
|
| 142 |
+
return TensorRef(this->data(), layout());
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 146 |
+
CUTLASS_HOST_DEVICE
|
| 147 |
+
ConstTensorRef const_ref() const {
|
| 148 |
+
return ConstTensorRef(this->data(), layout());
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/// Returns a TensorRef pointing to the first element of the tensor.
|
| 152 |
+
CUTLASS_HOST_DEVICE
|
| 153 |
+
TensorView view() {
|
| 154 |
+
return TensorView(ref(), extent());
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
/// Returns a TensorView to const data
|
| 158 |
+
CUTLASS_HOST_DEVICE
|
| 159 |
+
ConstTensorView const_view() const {
|
| 160 |
+
return ConstTensorView(const_ref(), extent());
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Returns a reference to the element at a given Coord
|
| 164 |
+
CUTLASS_HOST_DEVICE
|
| 165 |
+
Reference at(MatrixCoord const& coord) const {
|
| 166 |
+
typename Base::size_type offset_(layout().offset(coord));
|
| 167 |
+
return Base::at(offset_);
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
/// Returns the number of scalar elements needed to store tensor.
|
| 171 |
+
CUTLASS_HOST_DEVICE
|
| 172 |
+
LongIndex capacity() const {
|
| 173 |
+
return LongIndex(Base::size());
|
| 174 |
+
}
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 178 |
+
|
| 179 |
+
/// Column vector defined as a matrix with exactly one column
|
| 180 |
+
template <
|
| 181 |
+
typename Element,
|
| 182 |
+
int Rows,
|
| 183 |
+
typename Layout = layout::ColumnMajor
|
| 184 |
+
>
|
| 185 |
+
using ColumnVector = Matrix<Element, Rows, 1, Layout>;
|
| 186 |
+
|
| 187 |
+
/// Row vector defined as a matrix with exactly one row
|
| 188 |
+
template <
|
| 189 |
+
typename Element,
|
| 190 |
+
int Columns,
|
| 191 |
+
typename Layout = layout::RowMajor
|
| 192 |
+
>
|
| 193 |
+
using RowVector = Matrix<Element, 1, Columns, Layout>;
|
| 194 |
+
|
| 195 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 196 |
+
|
| 197 |
+
} // namespace thread
|
| 198 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/trace.h
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Helpers for optionally tracing through code when debugging.
|
| 33 |
+
|
| 34 |
+
This file is to be included after all other headers.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
// Tracing options
|
| 42 |
+
#ifndef CUTLASS_DEBUG_TRACE_LEVEL
|
| 43 |
+
#define CUTLASS_DEBUG_TRACE_LEVEL 0
|
| 44 |
+
#endif
|
| 45 |
+
|
| 46 |
+
#if CUTLASS_DEBUG_TRACE_LEVEL
|
| 47 |
+
#include <iostream>
|
| 48 |
+
#include "cutlass/core_io.h"
|
| 49 |
+
#if defined(__CUDA_ARCH__)
|
| 50 |
+
#define CUTLASS_TRACE_HOST(x)
|
| 51 |
+
#else
|
| 52 |
+
#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; }
|
| 53 |
+
#endif
|
| 54 |
+
#else
|
| 55 |
+
#define CUTLASS_TRACE_HOST(x)
|
| 56 |
+
#endif
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp
ADDED
|
@@ -0,0 +1,754 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing how threads are mapped to a given tile.
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cute/arch/mma_sm90_gmma.hpp"
|
| 38 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 39 |
+
|
| 40 |
+
namespace cutlass {
|
| 41 |
+
namespace transform {
|
| 42 |
+
namespace collective {
|
| 43 |
+
|
| 44 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
namespace detail {
|
| 47 |
+
using namespace cute;
|
| 48 |
+
|
| 49 |
+
template <bool Transpose, class SmemLayoutAtom, class ElementType>
|
| 50 |
+
constexpr auto
|
| 51 |
+
gmma_smem_transpose_or_passthrough() {
|
| 52 |
+
if constexpr (Transpose) {
|
| 53 |
+
if constexpr (cute::is_same_v<GMMA::Layout_MN_SW128_Atom<ElementType>, SmemLayoutAtom>) {
|
| 54 |
+
return GMMA::Layout_K_SW128_Atom<ElementType>{};
|
| 55 |
+
}
|
| 56 |
+
else if constexpr (cute::is_same_v<GMMA::Layout_MN_SW64_Atom<ElementType>, SmemLayoutAtom>) {
|
| 57 |
+
return GMMA::Layout_K_SW64_Atom<ElementType>{};
|
| 58 |
+
}
|
| 59 |
+
else if constexpr (cute::is_same_v<GMMA::Layout_MN_SW32_Atom<ElementType>, SmemLayoutAtom>) {
|
| 60 |
+
return GMMA::Layout_K_SW32_Atom<ElementType>{};
|
| 61 |
+
}
|
| 62 |
+
else if constexpr (cute::is_same_v<GMMA::Layout_MN_INTER_Atom<ElementType>, SmemLayoutAtom>) {
|
| 63 |
+
return GMMA::Layout_K_INTER_Atom<ElementType>{};
|
| 64 |
+
}
|
| 65 |
+
else {
|
| 66 |
+
static_assert(cutlass::detail::dependent_false<SmemLayoutAtom>, "Unsupported Layout_SW_Atom for B SMEM transposition");
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
else {
|
| 70 |
+
return SmemLayoutAtom{};
|
| 71 |
+
}
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
template <class SmemCopyAtom, class ElementType>
|
| 75 |
+
constexpr auto
|
| 76 |
+
use_universal_transposition() {
|
| 77 |
+
if constexpr (sizeof(ElementType) == 1) {
|
| 78 |
+
return !cute::is_same_v<GMMA::Layout_MN_SW128_Atom<ElementType>, SmemCopyAtom>;
|
| 79 |
+
}
|
| 80 |
+
else if constexpr (sizeof(ElementType) == 4){
|
| 81 |
+
// Only universal transposition can handle SW64 and Non swizzle SMEM layout
|
| 82 |
+
if constexpr (cute::is_same_v<GMMA::Layout_MN_SW64_Atom<ElementType>, SmemCopyAtom> ||
|
| 83 |
+
cute::is_same_v<GMMA::Layout_MN_INTER_Atom<ElementType>, SmemCopyAtom>) {
|
| 84 |
+
return true;
|
| 85 |
+
}
|
| 86 |
+
else {
|
| 87 |
+
return false;
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
else {
|
| 91 |
+
static_assert(cutlass::detail::dependent_false<ElementType>, "Unsupported ElementType for B SMEM transposition");
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
template<
|
| 96 |
+
class TiledMma_,
|
| 97 |
+
class SmemLayoutB_,
|
| 98 |
+
class SmemLayoutAtomB_,
|
| 99 |
+
class ElementB_>
|
| 100 |
+
class NoTranspositionOperandB {
|
| 101 |
+
public:
|
| 102 |
+
using TiledMma = TiledMma_;
|
| 103 |
+
using SmemLayoutB = SmemLayoutB_;
|
| 104 |
+
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
| 105 |
+
using ElementB = ElementB_;
|
| 106 |
+
|
| 107 |
+
constexpr CUTLASS_HOST_DEVICE
|
| 108 |
+
NoTranspositionOperandB(
|
| 109 |
+
int,
|
| 110 |
+
int,
|
| 111 |
+
TiledMma,
|
| 112 |
+
SmemLayoutB,
|
| 113 |
+
SmemLayoutAtomB,
|
| 114 |
+
ElementB) { }
|
| 115 |
+
|
| 116 |
+
template <
|
| 117 |
+
class TensorSmemB,
|
| 118 |
+
class TensorTransposedSmemB>
|
| 119 |
+
CUTLASS_DEVICE void operator()(
|
| 120 |
+
TensorSmemB const&,
|
| 121 |
+
TensorTransposedSmemB const&,
|
| 122 |
+
int, int) { }
|
| 123 |
+
|
| 124 |
+
CUTLASS_DEVICE void synchronize(int) { }
|
| 125 |
+
|
| 126 |
+
CUTLASS_DEVICE void synchronize() { }
|
| 127 |
+
|
| 128 |
+
template <
|
| 129 |
+
class TensorSmemB,
|
| 130 |
+
class TensorTransposedSmemB>
|
| 131 |
+
CUTLASS_DEVICE void transpose(
|
| 132 |
+
TensorSmemB const&,
|
| 133 |
+
TensorTransposedSmemB const&,
|
| 134 |
+
int) { }
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
template<
|
| 138 |
+
class TiledMma_,
|
| 139 |
+
class SmemLayoutB_,
|
| 140 |
+
class SmemLayoutAtomB_,
|
| 141 |
+
class ElementB_>
|
| 142 |
+
class UniversalTranspositionOperandB {
|
| 143 |
+
public:
|
| 144 |
+
using TiledMma = TiledMma_;
|
| 145 |
+
using SmemLayoutB = SmemLayoutB_;
|
| 146 |
+
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
| 147 |
+
using ElementB = ElementB_;
|
| 148 |
+
|
| 149 |
+
constexpr CUTLASS_HOST_DEVICE
|
| 150 |
+
UniversalTranspositionOperandB(
|
| 151 |
+
int warp_idx_,
|
| 152 |
+
int warp_group_thread_idx_,
|
| 153 |
+
TiledMma,
|
| 154 |
+
SmemLayoutB,
|
| 155 |
+
SmemLayoutAtomB,
|
| 156 |
+
ElementB)
|
| 157 |
+
: warp_idx(warp_idx_)
|
| 158 |
+
, warp_group_thread_idx(warp_group_thread_idx_) { }
|
| 159 |
+
|
| 160 |
+
template <
|
| 161 |
+
class TensorSmemB,
|
| 162 |
+
class TensorTransposedSmemB>
|
| 163 |
+
CUTLASS_DEVICE void operator()(
|
| 164 |
+
TensorSmemB const& sB,
|
| 165 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 166 |
+
int read_stage, int current_step) {
|
| 167 |
+
if (current_step > 0) {
|
| 168 |
+
return;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
| 172 |
+
static_assert(NumMathWarpGroup == 1 ||
|
| 173 |
+
(!detail::use_universal_transposition<SmemLayoutAtomB, ElementB>() && NumMathWarpGroup == 2),
|
| 174 |
+
"Wrong math warp group number for TransposeB");
|
| 175 |
+
constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
|
| 176 |
+
|
| 177 |
+
constexpr int BytesPerSmemSwizzleUnit = 16;
|
| 178 |
+
constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
|
| 179 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 180 |
+
/// Universal transposition, need warp_group sync between load and store.
|
| 181 |
+
/// The number of reg used depends on the input elementB.
|
| 182 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 183 |
+
/*
|
| 184 |
+
In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location.
|
| 185 |
+
In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements:
|
| 186 |
+
K
|
| 187 |
+
------------
|
| 188 |
+
| W0 W1 W2 W3 ---
|
| 189 |
+
| W0 W1 W2 W3 |
|
| 190 |
+
| W0 W1 W2 W3 | --> Copy Step 0
|
| 191 |
+
| W0 W1 W2 W3 ---
|
| 192 |
+
....
|
| 193 |
+
| W0 W1 W2 W3 ---
|
| 194 |
+
| W0 W1 W2 W3 |
|
| 195 |
+
| W0 W1 W2 W3 | --> Copy Step n
|
| 196 |
+
| W0 W1 W2 W3 ---
|
| 197 |
+
*/
|
| 198 |
+
static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout.");
|
| 199 |
+
constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<NumThreadsPerWarpGroup / WarpThreadShapeN>{}));
|
| 200 |
+
|
| 201 |
+
// Get copy tile and partition to each thread
|
| 202 |
+
auto sB_tiled_copy = make_tiled_copy(
|
| 203 |
+
Copy_Atom<DefaultCopy, ElementB>{},
|
| 204 |
+
WarpgroupThreadLayout, // thr_layout
|
| 205 |
+
Layout<_1>{} // val_layout
|
| 206 |
+
);
|
| 207 |
+
static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy.");
|
| 208 |
+
|
| 209 |
+
auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx);
|
| 210 |
+
Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K)
|
| 211 |
+
Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K)
|
| 212 |
+
|
| 213 |
+
// Divide partitioned tile to limit register usage
|
| 214 |
+
constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
|
| 215 |
+
constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB));
|
| 216 |
+
static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM.");
|
| 217 |
+
|
| 218 |
+
Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape);
|
| 219 |
+
Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape);
|
| 220 |
+
auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{}));
|
| 221 |
+
|
| 222 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 223 |
+
for (int step = 0; step < CopySteps; ++step) {
|
| 224 |
+
copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment);
|
| 225 |
+
|
| 226 |
+
// Make sure all elements are read before being overwritten
|
| 227 |
+
__syncthreads();
|
| 228 |
+
|
| 229 |
+
copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step));
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
CUTLASS_DEVICE void synchronize(int step) {
|
| 234 |
+
if (step == 0) {
|
| 235 |
+
// SMEM fence to make sure B is transposed before math
|
| 236 |
+
cutlass::arch::fence_view_async_shared();
|
| 237 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
CUTLASS_DEVICE void synchronize() {
|
| 242 |
+
// SMEM fence to make sure B is transposed before math
|
| 243 |
+
cutlass::arch::fence_view_async_shared();
|
| 244 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
template <
|
| 248 |
+
class TensorSmemB,
|
| 249 |
+
class TensorTransposedSmemB>
|
| 250 |
+
CUTLASS_DEVICE void transpose(
|
| 251 |
+
TensorSmemB const& sB,
|
| 252 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 253 |
+
int read_stage) {
|
| 254 |
+
|
| 255 |
+
this->operator()(sB, gmma_sB, read_stage, 0);
|
| 256 |
+
synchronize();
|
| 257 |
+
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
private:
|
| 261 |
+
const int warp_idx;
|
| 262 |
+
const int warp_group_thread_idx;
|
| 263 |
+
};
|
| 264 |
+
|
| 265 |
+
template<
|
| 266 |
+
class TiledMma_,
|
| 267 |
+
class SmemLayoutB_,
|
| 268 |
+
class SmemLayoutAtomB_,
|
| 269 |
+
class ElementB_>
|
| 270 |
+
class AsyncTranspositionOperandB {
|
| 271 |
+
public:
|
| 272 |
+
|
| 273 |
+
using TiledMma = TiledMma_;
|
| 274 |
+
using SmemLayoutB = SmemLayoutB_;
|
| 275 |
+
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
| 276 |
+
using ElementB = ElementB_;
|
| 277 |
+
|
| 278 |
+
static constexpr int Steps = 2;
|
| 279 |
+
static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
| 280 |
+
static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
|
| 281 |
+
static_assert(NumMathWarpGroup <= 2,
|
| 282 |
+
"Wrong math warp group number for TransposeB");
|
| 283 |
+
static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
|
| 284 |
+
static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
|
| 285 |
+
|
| 286 |
+
static constexpr int BytesPerSmemSwizzleUnit = 16;
|
| 287 |
+
static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
|
| 288 |
+
static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN;
|
| 289 |
+
static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
|
| 290 |
+
|
| 291 |
+
static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
|
| 292 |
+
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
|
| 293 |
+
static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
|
| 294 |
+
static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
|
| 295 |
+
static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
|
| 296 |
+
static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT.
|
| 297 |
+
static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits,
|
| 298 |
+
static constexpr int NumBitsPerStep = 3;
|
| 299 |
+
static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits)
|
| 300 |
+
static constexpr int NumBitsPerWarp = 12;
|
| 301 |
+
// Number of warp_group_tiles
|
| 302 |
+
static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
|
| 303 |
+
"Copy size must evenly divide SMEM tile.");
|
| 304 |
+
static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
|
| 305 |
+
|
| 306 |
+
static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK,
|
| 307 |
+
"Need to be able to transpose first k-block in the first step");
|
| 308 |
+
|
| 309 |
+
constexpr CUTLASS_HOST_DEVICE
|
| 310 |
+
AsyncTranspositionOperandB(
|
| 311 |
+
int warp_idx_,
|
| 312 |
+
int warp_group_thread_idx_,
|
| 313 |
+
TiledMma,
|
| 314 |
+
SmemLayoutB,
|
| 315 |
+
SmemLayoutAtomB,
|
| 316 |
+
ElementB)
|
| 317 |
+
: warp_idx(warp_idx_)
|
| 318 |
+
, warp_group_thread_idx(warp_group_thread_idx_)
|
| 319 |
+
, warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup)
|
| 320 |
+
, current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_
|
| 321 |
+
% NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp)
|
| 322 |
+
, current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_
|
| 323 |
+
% NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { }
|
| 324 |
+
|
| 325 |
+
template <
|
| 326 |
+
class TensorSmemB,
|
| 327 |
+
class TensorTransposedSmemB>
|
| 328 |
+
CUTLASS_DEVICE void operator()(
|
| 329 |
+
TensorSmemB const& sB,
|
| 330 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 331 |
+
int read_stage, int current_step)
|
| 332 |
+
{
|
| 333 |
+
if (current_step >= StepsPerWarpGroup) {
|
| 334 |
+
return;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
static constexpr auto WarpThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<WarpThreadShapeK>{}));
|
| 338 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 339 |
+
/// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize.
|
| 340 |
+
/// In each step, one warp would hold two warp_tiles.
|
| 341 |
+
/// Step 0: Step 1:
|
| 342 |
+
/// W0 W1 W2 W3 -- -- -- --
|
| 343 |
+
/// W1 W0 -- -- -- -- W3 W2
|
| 344 |
+
/// W2 -- -- -- -- W3 W0 W1
|
| 345 |
+
/// W3 -- -- -- -- W2 W1 W0
|
| 346 |
+
///
|
| 347 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 348 |
+
///
|
| 349 |
+
/// Fully static coord LUT to avoid extra register use.
|
| 350 |
+
/// [warp_id][step][warp_tile][n / k]
|
| 351 |
+
/// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7
|
| 352 |
+
/// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0
|
| 353 |
+
/// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1
|
| 354 |
+
/// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2
|
| 355 |
+
/// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3
|
| 356 |
+
///
|
| 357 |
+
/// Encoding the coord of warp tile0 into two int64_t values.
|
| 358 |
+
/// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern.
|
| 359 |
+
/// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0.
|
| 360 |
+
/// The 2-step transposition and the 8-step transposition share the same encoding.
|
| 361 |
+
///
|
| 362 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 363 |
+
|
| 364 |
+
// Divide entire SMEM to multiple warp_tiles
|
| 365 |
+
constexpr auto WarpTileShape = make_shape(Int<WarpTileSize>(), Int<WarpTileSize>());
|
| 366 |
+
Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape);
|
| 367 |
+
Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape);
|
| 368 |
+
|
| 369 |
+
// Get copy tile
|
| 370 |
+
auto sB_tiled_copy = make_tiled_copy(
|
| 371 |
+
Copy_Atom<DefaultCopy, ElementB>{},
|
| 372 |
+
WarpThreadLayout, // thr_layout
|
| 373 |
+
Layout<_1>{} // val_layout
|
| 374 |
+
);
|
| 375 |
+
|
| 376 |
+
static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy.");
|
| 377 |
+
auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx
|
| 378 |
+
|
| 379 |
+
// Construct fragments for transposition
|
| 380 |
+
Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{}))));
|
| 381 |
+
decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = {
|
| 382 |
+
make_fragment_like(tmp_tCsB),
|
| 383 |
+
make_fragment_like(tmp_tCsB)
|
| 384 |
+
};
|
| 385 |
+
|
| 386 |
+
[[maybe_unused]] int step = current_step * NumMathWarpGroup;
|
| 387 |
+
if constexpr (NumMathWarpGroup == 2) {
|
| 388 |
+
// For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8.
|
| 389 |
+
step += warp_idx / (NumWarpsPerWarpGroup * 2);
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step);
|
| 393 |
+
int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step);
|
| 394 |
+
|
| 395 |
+
if constexpr (NumMathWarpGroup == 2) {
|
| 396 |
+
tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
|
| 397 |
+
tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// decoding the warp tile coord.
|
| 401 |
+
int warp_tile0_n, warp_tile0_k;
|
| 402 |
+
if constexpr (StepsPerWarpGroup <= NumStepsEncoded) {
|
| 403 |
+
warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep;
|
| 404 |
+
warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep;
|
| 405 |
+
} else {
|
| 406 |
+
warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group;
|
| 407 |
+
warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k;
|
| 411 |
+
int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n;
|
| 412 |
+
|
| 413 |
+
CUTLASS_PRAGMA_UNROLL
|
| 414 |
+
for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) {
|
| 415 |
+
|
| 416 |
+
static_assert(TilesPerWarp == 2);
|
| 417 |
+
|
| 418 |
+
// [warp_tile][n/k]
|
| 419 |
+
const int warp_tile_coord[TilesPerWarp][2] = {
|
| 420 |
+
// n k
|
| 421 |
+
{warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0
|
| 422 |
+
{warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1
|
| 423 |
+
};
|
| 424 |
+
|
| 425 |
+
CUTLASS_PRAGMA_UNROLL
|
| 426 |
+
for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
|
| 427 |
+
Tensor tCsB = sB_thr_copy.partition_S(
|
| 428 |
+
flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
|
| 429 |
+
); // (CPY, CPY_N, CPY_K)
|
| 430 |
+
|
| 431 |
+
copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]);
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
// Make sure elements in two 8x8 warp tiles are all consumed
|
| 435 |
+
__syncwarp();
|
| 436 |
+
|
| 437 |
+
CUTLASS_PRAGMA_UNROLL
|
| 438 |
+
for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
|
| 439 |
+
Tensor tCsB_transposed = sB_thr_copy.partition_D(
|
| 440 |
+
flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
|
| 441 |
+
); // (CPY, CPY_N, CPY_K)
|
| 442 |
+
copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed);
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
} // loop warp_group_tile
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
CUTLASS_DEVICE void synchronize(int step) {
|
| 449 |
+
if (step < StepsPerWarpGroup) {
|
| 450 |
+
// SMEM fence to make sure B is transposed before math
|
| 451 |
+
cutlass::arch::fence_view_async_shared();
|
| 452 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
CUTLASS_DEVICE void synchronize() {
|
| 457 |
+
cutlass::arch::fence_view_async_shared();
|
| 458 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
template <
|
| 462 |
+
class TensorSmemB,
|
| 463 |
+
class TensorTransposedSmemB>
|
| 464 |
+
CUTLASS_DEVICE void transpose(
|
| 465 |
+
TensorSmemB const& sB,
|
| 466 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 467 |
+
int read_stage) {
|
| 468 |
+
|
| 469 |
+
CUTLASS_PRAGMA_UNROLL
|
| 470 |
+
for(int i = 0; i < StepsPerWarpGroup; ++i) {
|
| 471 |
+
this->operator()(sB, gmma_sB, read_stage, i);
|
| 472 |
+
}
|
| 473 |
+
synchronize();
|
| 474 |
+
|
| 475 |
+
}
|
| 476 |
+
private:
|
| 477 |
+
const int warp_idx;
|
| 478 |
+
const int warp_group_thread_idx;
|
| 479 |
+
const int warp_idx_in_warp_group;
|
| 480 |
+
const int current_warp_tile_n_coord_LUT;
|
| 481 |
+
const int current_warp_tile_k_coord_LUT;
|
| 482 |
+
};
|
| 483 |
+
|
| 484 |
+
template<
|
| 485 |
+
class TiledMma_,
|
| 486 |
+
class SmemLayoutB_,
|
| 487 |
+
class SmemLayoutAtomB_,
|
| 488 |
+
class ElementB_>
|
| 489 |
+
class AsyncTranspositionOperandB_1BElementB {
|
| 490 |
+
public:
|
| 491 |
+
|
| 492 |
+
static_assert(sizeof(ElementB_) == 1);
|
| 493 |
+
|
| 494 |
+
using TiledMma = TiledMma_;
|
| 495 |
+
using SmemLayoutB = SmemLayoutB_;
|
| 496 |
+
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
| 497 |
+
using ElementB = ElementB_;
|
| 498 |
+
|
| 499 |
+
static constexpr int Steps = 8;
|
| 500 |
+
static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup;
|
| 501 |
+
static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
|
| 502 |
+
static_assert(NumMathWarpGroup <= 2,
|
| 503 |
+
"Wrong math warp group number for TransposeB");
|
| 504 |
+
static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K.
|
| 505 |
+
static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
|
| 506 |
+
|
| 507 |
+
static constexpr int BytesPerSmemSwizzleUnit = 16;
|
| 508 |
+
static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB);
|
| 509 |
+
static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN;
|
| 510 |
+
static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
|
| 511 |
+
|
| 512 |
+
static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
|
| 513 |
+
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
|
| 514 |
+
static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
|
| 515 |
+
static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
|
| 516 |
+
static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
|
| 517 |
+
static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT.
|
| 518 |
+
static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits,
|
| 519 |
+
static constexpr int NumBitsPerStep = 3;
|
| 520 |
+
static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits)
|
| 521 |
+
static constexpr int NumBitsPerWarp = 12;
|
| 522 |
+
// Number of warp_group_tiles
|
| 523 |
+
static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
|
| 524 |
+
"Copy size must evenly divide SMEM tile.");
|
| 525 |
+
static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize;
|
| 526 |
+
|
| 527 |
+
constexpr CUTLASS_HOST_DEVICE
|
| 528 |
+
AsyncTranspositionOperandB_1BElementB(
|
| 529 |
+
int warp_idx_,
|
| 530 |
+
int warp_group_thread_idx_,
|
| 531 |
+
TiledMma,
|
| 532 |
+
SmemLayoutB,
|
| 533 |
+
SmemLayoutAtomB,
|
| 534 |
+
ElementB)
|
| 535 |
+
: warp_idx(warp_idx_)
|
| 536 |
+
, warp_group_thread_idx(warp_group_thread_idx_)
|
| 537 |
+
, warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup)
|
| 538 |
+
, current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_
|
| 539 |
+
% NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp)
|
| 540 |
+
, current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_
|
| 541 |
+
% NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { }
|
| 542 |
+
|
| 543 |
+
template <
|
| 544 |
+
class TensorSmemB,
|
| 545 |
+
class TensorTransposedSmemB>
|
| 546 |
+
CUTLASS_DEVICE void operator()(
|
| 547 |
+
TensorSmemB const& sB,
|
| 548 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 549 |
+
int read_stage, int current_step)
|
| 550 |
+
{
|
| 551 |
+
if (current_step > 0) {
|
| 552 |
+
return;
|
| 553 |
+
}
|
| 554 |
+
|
| 555 |
+
constexpr auto WarpThreadLayout = make_layout(make_shape(Int<WarpThreadShapeN>{}, Int<WarpThreadShapeK>{}));
|
| 556 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 557 |
+
/// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize.
|
| 558 |
+
/// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage.
|
| 559 |
+
/// Step 0: Step 1: Step 2: Step 3:
|
| 560 |
+
/// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 561 |
+
/// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 562 |
+
/// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 563 |
+
/// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 564 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- --
|
| 565 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2
|
| 566 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1
|
| 567 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0
|
| 568 |
+
///
|
| 569 |
+
/// Step 4: Step 5: Step 6: Step 7:
|
| 570 |
+
/// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 571 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --
|
| 572 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- --
|
| 573 |
+
/// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3
|
| 574 |
+
/// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- --
|
| 575 |
+
/// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- --
|
| 576 |
+
/// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- --
|
| 577 |
+
/// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- --
|
| 578 |
+
///
|
| 579 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 580 |
+
///
|
| 581 |
+
/// Fully static coord LUT to avoid extra register use.
|
| 582 |
+
/// [warp_id][step][warp_tile][n / k]
|
| 583 |
+
/// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7
|
| 584 |
+
/// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0
|
| 585 |
+
/// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1
|
| 586 |
+
/// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2
|
| 587 |
+
/// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3
|
| 588 |
+
///
|
| 589 |
+
/// Encoding the coord of warp tile0 into two int64_t values.
|
| 590 |
+
/// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern.
|
| 591 |
+
/// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0.
|
| 592 |
+
/// The 2-step transposition and the 8-step transposition share the same encoding.
|
| 593 |
+
///
|
| 594 |
+
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 595 |
+
|
| 596 |
+
// Divide entire SMEM to multiple warp_tiles
|
| 597 |
+
constexpr auto WarpTileShape = make_shape(Int<WarpTileSize>(), Int<WarpTileSize>());
|
| 598 |
+
Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape);
|
| 599 |
+
Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape);
|
| 600 |
+
|
| 601 |
+
// Get copy tile
|
| 602 |
+
auto sB_tiled_copy = make_tiled_copy(
|
| 603 |
+
Copy_Atom<DefaultCopy, ElementB>{},
|
| 604 |
+
WarpThreadLayout, // thr_layout
|
| 605 |
+
Layout<_1>{} // val_layout
|
| 606 |
+
);
|
| 607 |
+
static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy.");
|
| 608 |
+
auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx
|
| 609 |
+
|
| 610 |
+
// Construct fragments for transposition
|
| 611 |
+
Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{}))));
|
| 612 |
+
decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = {
|
| 613 |
+
make_fragment_like(tmp_tCsB),
|
| 614 |
+
make_fragment_like(tmp_tCsB)
|
| 615 |
+
};
|
| 616 |
+
|
| 617 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 618 |
+
for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) {
|
| 619 |
+
int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT;
|
| 620 |
+
int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT;
|
| 621 |
+
constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup;
|
| 622 |
+
|
| 623 |
+
if constexpr (NumMathWarpGroup == 2) {
|
| 624 |
+
tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
|
| 625 |
+
tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2));
|
| 626 |
+
}
|
| 627 |
+
|
| 628 |
+
CUTLASS_PRAGMA_NO_UNROLL
|
| 629 |
+
for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) {
|
| 630 |
+
// For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8.
|
| 631 |
+
int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2);
|
| 632 |
+
// decoding the warp tile coord.
|
| 633 |
+
int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group;
|
| 634 |
+
int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4;
|
| 635 |
+
int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k;
|
| 636 |
+
int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n;
|
| 637 |
+
|
| 638 |
+
tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep;
|
| 639 |
+
tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep;
|
| 640 |
+
|
| 641 |
+
static_assert(TilesPerWarp == 2);
|
| 642 |
+
|
| 643 |
+
// [warp_tile][n/k]
|
| 644 |
+
const int warp_tile_coord[TilesPerWarp][2] = {
|
| 645 |
+
// n k
|
| 646 |
+
{warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0
|
| 647 |
+
{warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1
|
| 648 |
+
};
|
| 649 |
+
|
| 650 |
+
CUTLASS_PRAGMA_UNROLL
|
| 651 |
+
for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
|
| 652 |
+
Tensor tCsB = sB_thr_copy.partition_S(
|
| 653 |
+
flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
|
| 654 |
+
); // (CPY, CPY_N, CPY_K)
|
| 655 |
+
|
| 656 |
+
copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]);
|
| 657 |
+
}
|
| 658 |
+
|
| 659 |
+
// Make sure elements in two 8x8 warp tiles are all consumed
|
| 660 |
+
__syncwarp();
|
| 661 |
+
|
| 662 |
+
CUTLASS_PRAGMA_UNROLL
|
| 663 |
+
for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) {
|
| 664 |
+
Tensor tCsB_transposed = sB_thr_copy.partition_D(
|
| 665 |
+
flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1])))
|
| 666 |
+
); // (CPY, CPY_N, CPY_K)
|
| 667 |
+
copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed);
|
| 668 |
+
}
|
| 669 |
+
} // lock step
|
| 670 |
+
} // loop warp_group_tile
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
CUTLASS_DEVICE void synchronize(int step) {
|
| 674 |
+
if (step == 0) {
|
| 675 |
+
// SMEM fence to make sure B is transposed before math
|
| 676 |
+
cutlass::arch::fence_view_async_shared();
|
| 677 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 678 |
+
}
|
| 679 |
+
}
|
| 680 |
+
|
| 681 |
+
CUTLASS_DEVICE void synchronize() {
|
| 682 |
+
cutlass::arch::fence_view_async_shared();
|
| 683 |
+
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier);
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
template <
|
| 687 |
+
class TensorSmemB,
|
| 688 |
+
class TensorTransposedSmemB>
|
| 689 |
+
CUTLASS_DEVICE void transpose(
|
| 690 |
+
TensorSmemB const& sB,
|
| 691 |
+
TensorTransposedSmemB const& gmma_sB,
|
| 692 |
+
int read_stage) {
|
| 693 |
+
this->operator()(sB, gmma_sB, read_stage, 0);
|
| 694 |
+
synchronize();
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
private:
|
| 698 |
+
const int warp_idx;
|
| 699 |
+
const int warp_group_thread_idx;
|
| 700 |
+
const int warp_idx_in_warp_group;
|
| 701 |
+
const int current_warp_tile_n_coord_LUT;
|
| 702 |
+
const int current_warp_tile_k_coord_LUT;
|
| 703 |
+
};
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
template<
|
| 707 |
+
class TiledMma,
|
| 708 |
+
class SmemLayoutB,
|
| 709 |
+
class SmemLayoutAtomB,
|
| 710 |
+
class ElementB,
|
| 711 |
+
bool TransposeB
|
| 712 |
+
>
|
| 713 |
+
constexpr CUTLASS_HOST_DEVICE
|
| 714 |
+
auto
|
| 715 |
+
make_transpose_operand_b(
|
| 716 |
+
int warp_idx,
|
| 717 |
+
int warp_group_thread_idx,
|
| 718 |
+
TiledMma,
|
| 719 |
+
SmemLayoutB,
|
| 720 |
+
SmemLayoutAtomB,
|
| 721 |
+
ElementB,
|
| 722 |
+
cute::bool_constant<TransposeB>)
|
| 723 |
+
{
|
| 724 |
+
if constexpr (!TransposeB) {
|
| 725 |
+
return NoTranspositionOperandB(
|
| 726 |
+
warp_idx, warp_group_thread_idx, TiledMma{},
|
| 727 |
+
SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
|
| 728 |
+
}
|
| 729 |
+
else if constexpr (use_universal_transposition<SmemLayoutAtomB, ElementB>()) {
|
| 730 |
+
return UniversalTranspositionOperandB(
|
| 731 |
+
warp_idx, warp_group_thread_idx, TiledMma{},
|
| 732 |
+
SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
|
| 733 |
+
}
|
| 734 |
+
else if constexpr (sizeof(ElementB) == 1) {
|
| 735 |
+
return AsyncTranspositionOperandB_1BElementB(
|
| 736 |
+
warp_idx, warp_group_thread_idx, TiledMma{},
|
| 737 |
+
SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
|
| 738 |
+
}
|
| 739 |
+
else {
|
| 740 |
+
return AsyncTranspositionOperandB(
|
| 741 |
+
warp_idx, warp_group_thread_idx, TiledMma{},
|
| 742 |
+
SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{});
|
| 743 |
+
}
|
| 744 |
+
}
|
| 745 |
+
|
| 746 |
+
}; // namespace detail
|
| 747 |
+
|
| 748 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 749 |
+
|
| 750 |
+
} // namespace collective
|
| 751 |
+
} // namespace transform
|
| 752 |
+
} // namespace cutlass
|
| 753 |
+
|
| 754 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Transform Kernel Universal adapter
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
// common
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/device_kernel.h"
|
| 40 |
+
#include "cutlass/gemm/gemm.h"
|
| 41 |
+
#include "cutlass/detail/layout.hpp"
|
| 42 |
+
#include "cutlass/detail/mma.hpp"
|
| 43 |
+
#include "cutlass/cuda_host_adapter.hpp"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/kernel_launch.h"
|
| 46 |
+
#if !defined(__CUDACC_RTC__)
|
| 47 |
+
#include "cutlass/cluster_launch.hpp"
|
| 48 |
+
#include "cutlass/trace.h"
|
| 49 |
+
#endif // !defined(__CUDACC_RTC__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
namespace cutlass::transform::device {
|
| 55 |
+
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
template <class TransformKernel_>
|
| 59 |
+
class TransformUniversalAdapter
|
| 60 |
+
{
|
| 61 |
+
public:
|
| 62 |
+
using TransformKernel = GetUnderlyingKernel_t<TransformKernel_>;
|
| 63 |
+
using Arguments = typename TransformKernel::Arguments;
|
| 64 |
+
using Params = typename TransformKernel::Params;
|
| 65 |
+
static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
private:
|
| 69 |
+
|
| 70 |
+
/// Kernel API parameters object
|
| 71 |
+
Params params_;
|
| 72 |
+
|
| 73 |
+
public:
|
| 74 |
+
|
| 75 |
+
/// Access the Params structure
|
| 76 |
+
Params const& params() const {
|
| 77 |
+
return params_;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
/// Determines whether the GEMM can execute the given problem.
|
| 81 |
+
static Status
|
| 82 |
+
can_implement(Arguments const& args) {
|
| 83 |
+
return TransformKernel::can_implement(args);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
/// Gets the workspace size
|
| 87 |
+
static size_t
|
| 88 |
+
get_workspace_size(Arguments const& args) {
|
| 89 |
+
size_t workspace_bytes = 0;
|
| 90 |
+
workspace_bytes += TransformKernel::get_workspace_size(args);
|
| 91 |
+
|
| 92 |
+
CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes);
|
| 93 |
+
|
| 94 |
+
return workspace_bytes;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/// Computes the grid shape
|
| 98 |
+
static dim3
|
| 99 |
+
get_grid_shape(Arguments const& args, void* workspace = nullptr) {
|
| 100 |
+
auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace);
|
| 101 |
+
return TransformKernel::get_grid_shape(tmp_params);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
/// Computes the grid shape
|
| 105 |
+
static dim3
|
| 106 |
+
get_grid_shape(Params const& params) {
|
| 107 |
+
return TransformKernel::get_grid_shape(params);
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
/// Initializes GEMM state from arguments.
|
| 112 |
+
Status
|
| 113 |
+
initialize(
|
| 114 |
+
Arguments const& args,
|
| 115 |
+
void* workspace = nullptr,
|
| 116 |
+
cudaStream_t stream = nullptr,
|
| 117 |
+
CudaHostAdapter* cuda_adapter = nullptr) {
|
| 118 |
+
|
| 119 |
+
CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace "
|
| 120 |
+
<< workspace << ", stream: " << (stream ? "non-null" : "null")
|
| 121 |
+
<< ", EnableCudaHostAdapter: " << (kEnableCudaHostAdapter ? "True" : "false"));
|
| 122 |
+
|
| 123 |
+
// Initialize the workspace
|
| 124 |
+
Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter);
|
| 125 |
+
if (status != Status::kSuccess) {
|
| 126 |
+
return status;
|
| 127 |
+
}
|
| 128 |
+
// Initialize the Params structure
|
| 129 |
+
params_ = TransformKernel::to_underlying_arguments(args, workspace);
|
| 130 |
+
// Don't set the function attributes - require the CudaHostAdapter to set it.
|
| 131 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 132 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 133 |
+
return Status::kSuccess;
|
| 134 |
+
}
|
| 135 |
+
else {
|
| 136 |
+
//
|
| 137 |
+
// Account for dynamic smem capacity if needed
|
| 138 |
+
//
|
| 139 |
+
int smem_size = TransformKernel::SharedStorageSize;
|
| 140 |
+
|
| 141 |
+
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
| 142 |
+
|
| 143 |
+
if (smem_size >= (48 << 10)) {
|
| 144 |
+
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
| 145 |
+
cudaError_t result = cudaFuncSetAttribute(
|
| 146 |
+
device_kernel<TransformKernel>,
|
| 147 |
+
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
| 148 |
+
smem_size);
|
| 149 |
+
if (cudaSuccess != result) {
|
| 150 |
+
result = cudaGetLastError(); // to clear the error bit
|
| 151 |
+
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
| 152 |
+
return Status::kErrorInternal;
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
}
|
| 156 |
+
return Status::kSuccess;
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
static Status
|
| 160 |
+
run(Params& params,
|
| 161 |
+
cudaStream_t stream = nullptr,
|
| 162 |
+
CudaHostAdapter *cuda_adapter = nullptr,
|
| 163 |
+
int32_t kernel_index = 0,
|
| 164 |
+
bool launch_with_pdl = false) {
|
| 165 |
+
CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()");
|
| 166 |
+
dim3 const block = TransformKernel::get_block_shape();
|
| 167 |
+
dim3 const grid = get_grid_shape(params);
|
| 168 |
+
|
| 169 |
+
// configure smem size and carveout
|
| 170 |
+
int smem_size = TransformKernel::SharedStorageSize;
|
| 171 |
+
|
| 172 |
+
Status launch_result{ Status::kSuccess };
|
| 173 |
+
// Use extended launch API only for mainloops that use it
|
| 174 |
+
if constexpr (TransformKernel::ArchTag::kMinComputeCapability >= 90) {
|
| 175 |
+
// Currently only support 1x1x1 for transform kernel.
|
| 176 |
+
dim3 const cluster = {1,1,1};
|
| 177 |
+
void* kernel_params[] = {¶ms};
|
| 178 |
+
|
| 179 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 180 |
+
//
|
| 181 |
+
// Use the cuda host adapter
|
| 182 |
+
//
|
| 183 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 184 |
+
if (cuda_adapter) {
|
| 185 |
+
|
| 186 |
+
if (launch_with_pdl) {
|
| 187 |
+
CUTLASS_TRACE_HOST(
|
| 188 |
+
"TransformUniversalAdapter::run() does not support launching with PDL and a custom cuda adapter.");
|
| 189 |
+
return Status::kErrorInternal;
|
| 190 |
+
}
|
| 191 |
+
launch_result = cuda_adapter->launch(grid,
|
| 192 |
+
cluster,
|
| 193 |
+
block,
|
| 194 |
+
smem_size,
|
| 195 |
+
stream,
|
| 196 |
+
kernel_params,
|
| 197 |
+
kernel_index);
|
| 198 |
+
CUTLASS_TRACE_HOST("Kernel Launch Result" << cutlassGetStatusString(launch_result));
|
| 199 |
+
}
|
| 200 |
+
else {
|
| 201 |
+
return Status::kErrorInternal;
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
else {
|
| 205 |
+
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
| 206 |
+
void const* kernel = (void const*) device_kernel<TransformKernel>;
|
| 207 |
+
if constexpr (TransformKernel::ArchTag::kMinComputeCapability == 90) {
|
| 208 |
+
launch_result = ClusterLauncher::launch(
|
| 209 |
+
grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
else {
|
| 214 |
+
launch_result = Status::kSuccess;
|
| 215 |
+
cutlass::arch::synclog_setup();
|
| 216 |
+
|
| 217 |
+
if constexpr (kEnableCudaHostAdapter) {
|
| 218 |
+
CUTLASS_ASSERT(cuda_adapter);
|
| 219 |
+
if (cuda_adapter) {
|
| 220 |
+
void* kernel_params[] = {¶ms};
|
| 221 |
+
|
| 222 |
+
launch_result = cuda_adapter->launch(
|
| 223 |
+
grid, block, smem_size, stream, kernel_params, 0
|
| 224 |
+
);
|
| 225 |
+
|
| 226 |
+
}
|
| 227 |
+
else {
|
| 228 |
+
return Status::kErrorInternal;
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
else {
|
| 232 |
+
CUTLASS_ASSERT(cuda_adapter == nullptr);
|
| 233 |
+
cutlass::kernel_launch<TransformKernel>(grid, block, smem_size, stream, params, launch_with_pdl);
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
cudaError_t result = cudaGetLastError();
|
| 238 |
+
if (cudaSuccess == result && Status::kSuccess == launch_result) {
|
| 239 |
+
return Status::kSuccess;
|
| 240 |
+
}
|
| 241 |
+
else if (cudaSuccess != result) {
|
| 242 |
+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
|
| 243 |
+
}
|
| 244 |
+
else if (Status::kSuccess != launch_result) {
|
| 245 |
+
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cutlassGetStatusString(launch_result));
|
| 246 |
+
}
|
| 247 |
+
return Status::kErrorInternal;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
//
|
| 251 |
+
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
| 252 |
+
//
|
| 253 |
+
|
| 254 |
+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
| 255 |
+
Status
|
| 256 |
+
run(
|
| 257 |
+
Arguments const& args,
|
| 258 |
+
void* workspace = nullptr,
|
| 259 |
+
cudaStream_t stream = nullptr,
|
| 260 |
+
CudaHostAdapter *cuda_adapter = nullptr,
|
| 261 |
+
int32_t kernel_index = 0,
|
| 262 |
+
bool launch_with_pdl = false
|
| 263 |
+
) {
|
| 264 |
+
Status status = initialize(args, workspace, stream, cuda_adapter);
|
| 265 |
+
|
| 266 |
+
if (Status::kSuccess == status) {
|
| 267 |
+
status = run(params_, stream, cuda_adapter, kernel_index, launch_with_pdl);
|
| 268 |
+
}
|
| 269 |
+
return status;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
| 273 |
+
Status
|
| 274 |
+
operator()(
|
| 275 |
+
Arguments const& args,
|
| 276 |
+
void* workspace = nullptr,
|
| 277 |
+
cudaStream_t stream = nullptr,
|
| 278 |
+
CudaHostAdapter *cuda_adapter = nullptr,
|
| 279 |
+
bool launch_with_pdl = false) {
|
| 280 |
+
return run(args, workspace, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
| 284 |
+
Status
|
| 285 |
+
run(
|
| 286 |
+
cudaStream_t stream = nullptr,
|
| 287 |
+
CudaHostAdapter *cuda_adapter = nullptr,
|
| 288 |
+
bool launch_with_pdl = false) {
|
| 289 |
+
return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
| 293 |
+
Status
|
| 294 |
+
operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) {
|
| 295 |
+
return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl);
|
| 296 |
+
}
|
| 297 |
+
};
|
| 298 |
+
|
| 299 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 300 |
+
|
| 301 |
+
} // namespace cutlass::transform::device
|
| 302 |
+
|
| 303 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/* \file
|
| 33 |
+
\brief Convolution filter format transformation kernel.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <algorithm>
|
| 39 |
+
#include <random>
|
| 40 |
+
|
| 41 |
+
#include "cutlass/coord.h"
|
| 42 |
+
#include "cutlass/arch/arch.h"
|
| 43 |
+
#include "cutlass/layout/matrix.h"
|
| 44 |
+
#include "cutlass/cuda_host_adapter.hpp"
|
| 45 |
+
|
| 46 |
+
#include "cute/int_tuple.hpp"
|
| 47 |
+
#include "cute/tensor.hpp"
|
| 48 |
+
#include "cute/config.hpp"
|
| 49 |
+
|
| 50 |
+
namespace cutlass::transform::kernel {
|
| 51 |
+
|
| 52 |
+
using namespace cute;
|
| 53 |
+
|
| 54 |
+
enum class FilterFormat {
|
| 55 |
+
CKTRS,
|
| 56 |
+
CTRSK,
|
| 57 |
+
KTRSC
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
template <
|
| 61 |
+
FilterFormat SrcFormat,
|
| 62 |
+
FilterFormat DstFormat,
|
| 63 |
+
int NumDimensions,
|
| 64 |
+
class Element_,
|
| 65 |
+
int AlignmentBytes = 16
|
| 66 |
+
>
|
| 67 |
+
struct ConvFilterFormatTransformer {
|
| 68 |
+
|
| 69 |
+
using Element = Element_;
|
| 70 |
+
static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported");
|
| 71 |
+
static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported");
|
| 72 |
+
static_assert(AlignmentBytes > 0 && AlignmentBytes % static_cast<int>(sizeof(Element)) == 0, "Invalid alignment setting");
|
| 73 |
+
|
| 74 |
+
// In ktrsc order.
|
| 75 |
+
using FilterExtent = array<int, NumDimensions>;
|
| 76 |
+
|
| 77 |
+
// Default cta tile shape: 32x32
|
| 78 |
+
static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<32>{});
|
| 79 |
+
// Default thread layout: (4, 32)
|
| 80 |
+
static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{}));
|
| 81 |
+
|
| 82 |
+
static constexpr uint32_t MaxThreadsPerBlock = 128;
|
| 83 |
+
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 84 |
+
|
| 85 |
+
using ArchTag = arch::Sm90;
|
| 86 |
+
|
| 87 |
+
// Default ctor
|
| 88 |
+
CUTLASS_HOST_DEVICE
|
| 89 |
+
ConvFilterFormatTransformer() {}
|
| 90 |
+
|
| 91 |
+
struct Arguments {
|
| 92 |
+
const void *src_ptr;
|
| 93 |
+
void *dst_ptr;
|
| 94 |
+
FilterExtent filter_extent;
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
struct Params {
|
| 98 |
+
using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr<const Element>(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{}))));
|
| 99 |
+
using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr<Element>(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0)))));
|
| 100 |
+
|
| 101 |
+
TensorSrc src;
|
| 102 |
+
TensorDst dst;
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
struct SharedStorage {
|
| 106 |
+
/* empty, no smem needed */
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 110 |
+
|
| 111 |
+
static Status
|
| 112 |
+
can_implement(Arguments const& args) {
|
| 113 |
+
bool implementable = true;
|
| 114 |
+
// alignment rule
|
| 115 |
+
{
|
| 116 |
+
int contiguous_dim = DstFormat == FilterFormat::CTRSK ? args.filter_extent[0] : args.filter_extent[NumDimensions - 1];
|
| 117 |
+
int align_element = AlignmentBytes / static_cast<int>(sizeof(Element));
|
| 118 |
+
|
| 119 |
+
implementable &= (contiguous_dim % align_element == 0);
|
| 120 |
+
|
| 121 |
+
if (!implementable) {
|
| 122 |
+
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Alignment setting is invalid.\n");
|
| 123 |
+
return Status::kInvalid;
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
return Status::kSuccess;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
static size_t
|
| 131 |
+
get_workspace_size(Arguments const& args) {
|
| 132 |
+
return 0;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
static dim3
|
| 136 |
+
get_block_shape() {
|
| 137 |
+
return dim3(size(shape(ThreadLayout)), 1, 1);
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
static dim3
|
| 141 |
+
get_grid_shape(Params const& params) {
|
| 142 |
+
auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape));
|
| 143 |
+
auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape));
|
| 144 |
+
|
| 145 |
+
return dim3(dim_m, dim_n, 1);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
static cutlass::Status
|
| 149 |
+
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
| 150 |
+
CudaHostAdapter *cuda_adapter = nullptr) {
|
| 151 |
+
return Status::kSuccess;
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
static Params
|
| 155 |
+
to_underlying_arguments(Arguments const& args, void* workspace) {
|
| 156 |
+
auto k = args.filter_extent[0];
|
| 157 |
+
auto c = args.filter_extent[NumDimensions - 1];
|
| 158 |
+
auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent));
|
| 159 |
+
|
| 160 |
+
// source shape (s,r,t,k,c)
|
| 161 |
+
auto shape_src = flatten(make_shape(srt, k, c));
|
| 162 |
+
auto shape_dst = DstFormat == FilterFormat::CTRSK ? make_shape(k, c * product(srt)) : make_shape(c, k * product(srt));
|
| 163 |
+
|
| 164 |
+
auto src = make_tensor(make_gmem_ptr(recast_ptr<const Element>(args.src_ptr)), make_layout(shape_src));
|
| 165 |
+
auto dst = make_tensor(make_gmem_ptr(recast_ptr<Element>(args.dst_ptr)), make_layout(shape_dst));
|
| 166 |
+
|
| 167 |
+
return Params{src, dst};
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
CUTLASS_DEVICE
|
| 171 |
+
void operator()(Params const& params, char *smem_buf) {
|
| 172 |
+
// Tile the input tensor into blocks
|
| 173 |
+
auto block_coord = make_coord(blockIdx.x, blockIdx.y);
|
| 174 |
+
auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<32>{});
|
| 175 |
+
// Default thread layout: (4, 32)
|
| 176 |
+
auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{}));
|
| 177 |
+
auto vec_layout = make_layout(make_shape(Int<AlignmentBytes / static_cast<int>(sizeof(Element))>{}, Int<1>{}));
|
| 178 |
+
|
| 179 |
+
Tensor tile_D = local_tile(params.dst, block_shape, block_coord);
|
| 180 |
+
|
| 181 |
+
// Construct tiled copy
|
| 182 |
+
using AccessType = cutlass::AlignedArray<Element, size(vec_layout)>;
|
| 183 |
+
using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;
|
| 184 |
+
|
| 185 |
+
auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout);
|
| 186 |
+
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
|
| 187 |
+
Tensor thr_tile_D = thr_copy.partition_D(tile_D);
|
| 188 |
+
|
| 189 |
+
// shape (s, r, t)
|
| 190 |
+
auto shape_trs = take<0, NumDimensions - 2>(shape(params.src));
|
| 191 |
+
// strided_c = c for format CTRSK, strided_c = k for format KTRSC
|
| 192 |
+
auto strided_c = DstFormat == FilterFormat::CTRSK ? get<NumDimensions - 1>(shape(params.src)) : get<NumDimensions - 2>(shape(params.src));
|
| 193 |
+
// shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC
|
| 194 |
+
auto shape_ctrs = append<NumDimensions - 1>(shape_trs, strided_c);
|
| 195 |
+
auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs);
|
| 196 |
+
// index of k for format CTRSK and index of c for format KTRSC
|
| 197 |
+
auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout)));
|
| 198 |
+
int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout)));
|
| 199 |
+
|
| 200 |
+
// Fragment to load from S and store to D
|
| 201 |
+
auto frag = make_fragment_like(thr_tile_D);
|
| 202 |
+
// Predicate tensor.
|
| 203 |
+
Tensor thr_tile_P = make_tensor<bool>(shape(thr_tile_D));
|
| 204 |
+
|
| 205 |
+
CUTLASS_PRAGMA_UNROLL
|
| 206 |
+
for (int i = 0; i < size(frag); ++i) {
|
| 207 |
+
auto srt_coord = take<0, NumDimensions - 2>(srtc_coord);
|
| 208 |
+
auto kc_coord = DstFormat == FilterFormat::CTRSK ?
|
| 209 |
+
make_coord(n_idx+i, get<NumDimensions - 2>(srtc_coord)) :
|
| 210 |
+
make_coord(get<NumDimensions - 2>(srtc_coord), n_idx+i);
|
| 211 |
+
auto coord = flatten(make_coord(srt_coord, kc_coord));
|
| 212 |
+
thr_tile_P(i) = elem_less(coord, shape(params.src));
|
| 213 |
+
if (thr_tile_P(i)) {
|
| 214 |
+
frag(i) = params.src(coord);
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// Copy from RMEM to GMEM
|
| 219 |
+
copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D);
|
| 220 |
+
}
|
| 221 |
+
};
|
| 222 |
+
|
| 223 |
+
} // namespace cutlass::transform::kernel
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Compress utils specific for SM90 structure sparse kernels
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cute/container/bit_field.hpp" // cute::bit_field
|
| 39 |
+
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v, cute::uint_bit_t
|
| 40 |
+
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
|
| 41 |
+
#include "cute/algorithm/cooperative_copy.hpp" // cute::cooperative_copy
|
| 42 |
+
#include "cutlass/arch/arch.h" // cutlass::arch::Sm90
|
| 43 |
+
#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
|
| 44 |
+
#include "cutlass/cutlass.h" // cutlass::Status
|
| 45 |
+
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
|
| 46 |
+
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
| 47 |
+
#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo
|
| 48 |
+
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
| 49 |
+
#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v
|
| 50 |
+
#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter
|
| 51 |
+
|
| 52 |
+
namespace cutlass::transform::kernel {
|
| 53 |
+
|
| 54 |
+
using namespace cute;
|
| 55 |
+
|
| 56 |
+
template<
|
| 57 |
+
class ProblemShape_,
|
| 58 |
+
class ElementA_,
|
| 59 |
+
class LayoutATag_,
|
| 60 |
+
class SparseConfig_
|
| 61 |
+
>
|
| 62 |
+
class SM90StructuredSparseCompressor {
|
| 63 |
+
public:
|
| 64 |
+
using SparseConfig = SparseConfig_;
|
| 65 |
+
using ProblemShape = ProblemShape_;
|
| 66 |
+
|
| 67 |
+
// * EltA
|
| 68 |
+
using ElementA = ElementA_;
|
| 69 |
+
using ElementAUint = cute::uint_bit_t<cute::sizeof_bits_v<ElementA>>;
|
| 70 |
+
using ElementAMma = typename SparseConfig::ElementAMma;
|
| 71 |
+
using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
|
| 72 |
+
using ElementAMmaRawUnit = cute::uint_bit_t<cute::sizeof_bits_v<ElementAMmaRaw>>;
|
| 73 |
+
using ElementASparsity = typename SparseConfig::ElementASparsity;
|
| 74 |
+
using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
|
| 75 |
+
using ElementAUintCompressed = cute::sparse_elem<ElementASparsity{}, ElementAUint>;
|
| 76 |
+
using LayoutATag = LayoutATag_;
|
| 77 |
+
using LayoutA = LayoutATag;
|
| 78 |
+
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;
|
| 79 |
+
|
| 80 |
+
// * EltE
|
| 81 |
+
using ElementEMma = typename SparseConfig::ElementEMma;
|
| 82 |
+
using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
|
| 83 |
+
using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;
|
| 84 |
+
// Data Type for storing one chunk's metadata
|
| 85 |
+
static constexpr int ElementEBitsPerChunk = typename SparseConfig::ElementEBitsPerChunk{};
|
| 86 |
+
CUTE_STATIC_ASSERT(ElementEBitsPerChunk == 4, "ElementEBitsPerChunk is 4 for SM90");
|
| 87 |
+
using ElementEChunk = cute::uint_bit_t<ElementEBitsPerChunk>;
|
| 88 |
+
CUTE_STATIC_ASSERT(cute::is_same_v<ElementEChunk, cute::uint4_t>, "ElementEChunk is uint4_t for SM90");
|
| 89 |
+
using ElementESparsityPerChunk = Int<ElementEMmaSparsity{} / (cute::sizeof_bits_v<ElementEMmaRaw> / ElementEBitsPerChunk)>;
|
| 90 |
+
|
| 91 |
+
// AtomE
|
| 92 |
+
using TensorEAtom = typename SparseConfig::TensorEAtom;
|
| 93 |
+
using TensorEAtomK = typename SparseConfig::TensorEAtomK;
|
| 94 |
+
using TensorEAtomM = typename SparseConfig::TensorEAtomM;
|
| 95 |
+
|
| 96 |
+
static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
|
| 97 |
+
static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
|
| 98 |
+
static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
|
| 99 |
+
static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
| 100 |
+
static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
| 101 |
+
|
| 102 |
+
// * Alignment
|
| 103 |
+
static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
|
| 104 |
+
static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
|
| 105 |
+
static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
|
| 106 |
+
static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};
|
| 107 |
+
|
| 108 |
+
// Required by `device_kernel`
|
| 109 |
+
static constexpr int MaxThreadsPerBlock = TensorEAtomM{};
|
| 110 |
+
static constexpr int MinBlocksPerMultiprocessor = 1;
|
| 111 |
+
using ArchTag = arch::Sm90;
|
| 112 |
+
|
| 113 |
+
struct SharedStorage {
|
| 114 |
+
ElementEMma cEsE[cute::size(TensorEAtom{})];
|
| 115 |
+
ElementAUintCompressed cACsAC[cute::size(TensorEAtom{})];
|
| 116 |
+
ElementAUint cAsA[cute::size(TensorEAtom{})];
|
| 117 |
+
};
|
| 118 |
+
|
| 119 |
+
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 120 |
+
|
| 121 |
+
struct TransformArguments {
|
| 122 |
+
void const* ptr_A{nullptr};
|
| 123 |
+
StrideA dA{};
|
| 124 |
+
void* ptr_ACompress{nullptr};
|
| 125 |
+
void* ptr_E{nullptr};
|
| 126 |
+
};
|
| 127 |
+
|
| 128 |
+
using TransformParams = TransformArguments;
|
| 129 |
+
|
| 130 |
+
struct Arguments {
|
| 131 |
+
ProblemShape problem_shape{};
|
| 132 |
+
TransformArguments transform{};
|
| 133 |
+
KernelHardwareInfo hw_info{};
|
| 134 |
+
};
|
| 135 |
+
|
| 136 |
+
struct Params {
|
| 137 |
+
ProblemShape problem_shape{};
|
| 138 |
+
TransformParams transform{};
|
| 139 |
+
KernelHardwareInfo hw_info{};
|
| 140 |
+
void* workspace = nullptr;
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
public:
|
| 144 |
+
static Params
|
| 145 |
+
to_underlying_arguments(Arguments const& args, void* workspace = nullptr) {
|
| 146 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::to_underlying_arguments()");
|
| 147 |
+
return Params{{args.problem_shape},
|
| 148 |
+
{args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E},
|
| 149 |
+
{args.hw_info},
|
| 150 |
+
workspace};
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
static Status
|
| 154 |
+
can_implement(Arguments const& args) {
|
| 155 |
+
auto [M, N, K, L] = args.problem_shape;
|
| 156 |
+
if (K % LogicalElemsAPerChunk != 0) {
|
| 157 |
+
CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size");
|
| 158 |
+
return Status::kErrorInvalidProblem;
|
| 159 |
+
}
|
| 160 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::can_implement() (True)");
|
| 161 |
+
return Status::kSuccess;
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
static size_t
|
| 165 |
+
get_workspace_size(Arguments const& args) {
|
| 166 |
+
CUTLASS_UNUSED(args);
|
| 167 |
+
// Backward compatible with host compressor
|
| 168 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_workspace_size() (" << SharedStorageSize << ")");
|
| 169 |
+
return SharedStorageSize;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
static Status
|
| 173 |
+
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr,
|
| 174 |
+
CudaHostAdapter *cuda_adapter = nullptr) {
|
| 175 |
+
CUTLASS_UNUSED(args);
|
| 176 |
+
CUTLASS_UNUSED(workspace);
|
| 177 |
+
CUTLASS_UNUSED(stream);
|
| 178 |
+
CUTLASS_UNUSED(cuda_adapter);
|
| 179 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::initialize_workspace()");
|
| 180 |
+
return Status::kSuccess;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
static dim3
|
| 184 |
+
get_grid_shape(Params const& params) {
|
| 185 |
+
constexpr int MaxAlignmentM = cutlass::const_max(TensorEAlignmentM, TensorAAlignmentM);
|
| 186 |
+
constexpr int MaxAlignmentK = cutlass::const_max(TensorEAlignmentK, TensorAAlignmentK);
|
| 187 |
+
const auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape;
|
| 188 |
+
|
| 189 |
+
const int GemmMAlignedMax = cutlass::round_up(GemmM, MaxAlignmentM);
|
| 190 |
+
const int GemmKAlignedMax = cutlass::round_up(GemmK, MaxAlignmentK);
|
| 191 |
+
|
| 192 |
+
const int gridDim_X = cutlass::ceil_div(GemmMAlignedMax, TensorEAtomM{});
|
| 193 |
+
const int gridDim_Y = cutlass::ceil_div(GemmKAlignedMax, TensorEAtomK{});
|
| 194 |
+
const int gridDim_Z = GemmL;
|
| 195 |
+
|
| 196 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_grid_shape() ("
|
| 197 |
+
<< gridDim_X << ", "
|
| 198 |
+
<< gridDim_Y << ", "
|
| 199 |
+
<< gridDim_Z << ")");
|
| 200 |
+
return dim3(gridDim_X, gridDim_Y, gridDim_Z);
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
static dim3
|
| 204 |
+
get_block_shape() {
|
| 205 |
+
CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_block_shape() ("
|
| 206 |
+
<< MaxThreadsPerBlock << ", "
|
| 207 |
+
<< 1 << ", "
|
| 208 |
+
<< 1 << ")");
|
| 209 |
+
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
CUTE_DEVICE
|
| 213 |
+
void
|
| 214 |
+
operator()(Params params, void* smem_buf = nullptr) {
|
| 215 |
+
run(params, smem_buf);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
CUTE_DEVICE
|
| 219 |
+
static void
|
| 220 |
+
run(Params params, void* smem_buf = nullptr) {
|
| 221 |
+
structure_sparse_compress(params, smem_buf);
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
private:
|
| 225 |
+
|
| 226 |
+
struct MetadataOneChunk1to2 {
|
| 227 |
+
|
| 228 |
+
CUTE_DEVICE
|
| 229 |
+
void set_metadata_bits(int elt_log_idx, int elt_phy_idx) {
|
| 230 |
+
auto metadata_bits = [&]() -> uint8_t {
|
| 231 |
+
CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 2);
|
| 232 |
+
switch (elt_log_idx) {
|
| 233 |
+
case 0:
|
| 234 |
+
return 0b0100;
|
| 235 |
+
case 1:
|
| 236 |
+
return 0b1110;
|
| 237 |
+
default:
|
| 238 |
+
CUTE_GCC_UNREACHABLE;
|
| 239 |
+
}
|
| 240 |
+
};
|
| 241 |
+
|
| 242 |
+
storage_ |= (metadata_bits() << (4 * elt_phy_idx));
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
CUTE_DEVICE
|
| 247 |
+
ElementEChunk storage() const {
|
| 248 |
+
return ElementEChunk{storage_};
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
private:
|
| 252 |
+
uint8_t storage_ = 0b0000;
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
struct MetadataOneChunk2to4{
|
| 256 |
+
|
| 257 |
+
CUTE_DEVICE
|
| 258 |
+
void set_metadata_bits(int elt_log_idx, int elt_phy_idx) {
|
| 259 |
+
auto metadata_bits = [&]() -> uint8_t {
|
| 260 |
+
CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 4);
|
| 261 |
+
switch (elt_log_idx) {
|
| 262 |
+
case 0:
|
| 263 |
+
return 0b00;
|
| 264 |
+
case 1:
|
| 265 |
+
return 0b01;
|
| 266 |
+
case 2:
|
| 267 |
+
return 0b10;
|
| 268 |
+
case 3:
|
| 269 |
+
return 0b11;
|
| 270 |
+
default:
|
| 271 |
+
CUTLASS_ASSERT(false);
|
| 272 |
+
CUTE_GCC_UNREACHABLE;
|
| 273 |
+
return 0b00;
|
| 274 |
+
}
|
| 275 |
+
};
|
| 276 |
+
|
| 277 |
+
storage_ |= (metadata_bits() << (2 * elt_phy_idx));
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
CUTE_DEVICE
|
| 281 |
+
ElementEChunk storage() const {
|
| 282 |
+
return ElementEChunk{storage_};
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
private:
|
| 286 |
+
uint8_t storage_ = 0b0000;
|
| 287 |
+
};
|
| 288 |
+
|
| 289 |
+
using MetadataOneChunk = cute::conditional_t<SparseConfig::IsTF32,
|
| 290 |
+
MetadataOneChunk1to2,
|
| 291 |
+
MetadataOneChunk2to4>;
|
| 292 |
+
|
| 293 |
+
private:
|
| 294 |
+
|
| 295 |
+
CUTE_DEVICE
|
| 296 |
+
static void
|
| 297 |
+
structure_sparse_compress(Params params, void* smem_buf) {
|
| 298 |
+
// * Input Params
|
| 299 |
+
auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape;
|
| 300 |
+
auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform;
|
| 301 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 302 |
+
|
| 303 |
+
[[maybe_unused]] const int gridDim_X = gridDim.x;
|
| 304 |
+
[[maybe_unused]] const int gridDim_Y = gridDim.y;
|
| 305 |
+
[[maybe_unused]] const int gridDim_Z = gridDim.z;
|
| 306 |
+
[[maybe_unused]] const int blockDim_X = blockDim.x;
|
| 307 |
+
|
| 308 |
+
// * Global Tensor Layout
|
| 309 |
+
const cute::Layout layout_gA = make_layout(make_shape(GemmM, GemmK, GemmL), dA);
|
| 310 |
+
const cute::Layout layout_gAC = SparseConfig::fill_layoutA(params.problem_shape);
|
| 311 |
+
const cute::Layout layout_gE = SparseConfig::fill_layoutE(params.problem_shape);
|
| 312 |
+
|
| 313 |
+
// * Construct Global Tensor
|
| 314 |
+
const cute::Tensor gA = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementAUint>(ptr_A)), layout_gA);
|
| 315 |
+
cute::Tensor gAC_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementAUintCompressed>(ptr_ACompress)), layout_gAC );
|
| 316 |
+
cute::Tensor gAC = cute::recast<ElementAUint>(gAC_sparse);
|
| 317 |
+
cute::Tensor gE_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr<ElementEMma>(ptr_E)), layout_gE);
|
| 318 |
+
cute::Tensor gE = cute::recast<ElementEMmaRaw>(gE_sparse);
|
| 319 |
+
|
| 320 |
+
// * CTA Tensor Layout
|
| 321 |
+
using cAsA_layout_row = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutRight{}));
|
| 322 |
+
using cAsA_layout_col = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutLeft{}));
|
| 323 |
+
using cAsA_layout = cute::conditional_t<cute::is_same_v<LayoutATag, layout::RowMajor>, cAsA_layout_row, cAsA_layout_col>;
|
| 324 |
+
using cACsAC_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementASparsity{}), LayoutRight{}));
|
| 325 |
+
using cEsE_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementEMmaSparsity{}), LayoutRight{}));
|
| 326 |
+
|
| 327 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<TensorEAtom>, "TensorEAtom needs to be static");
|
| 328 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<cAsA_layout>, "cAsA_layout needs to be static");
|
| 329 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<cACsAC_layout>, "cACsAC_layout needs to be static");
|
| 330 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<cEsE_layout>, "cEsE_layout needs to be static");
|
| 331 |
+
|
| 332 |
+
const int blockIdx_X = blockIdx.x;
|
| 333 |
+
const int blockIdx_Y = blockIdx.y;
|
| 334 |
+
const int blockIdx_Z = blockIdx.z;
|
| 335 |
+
const int threadIdx_X = threadIdx.x;
|
| 336 |
+
|
| 337 |
+
// * Construct CTA Tensor
|
| 338 |
+
const auto cta_coord = make_coord(blockIdx_X, blockIdx_Y, blockIdx_Z);
|
| 339 |
+
cute::Tensor cAgA = cute::recast<ElementAMmaRawUnit>(local_tile(gA, shape(cAsA_layout{}), cta_coord));
|
| 340 |
+
cute::Tensor cACgAC = cute::recast<ElementAMmaRawUnit>(local_tile(gAC, shape(cACsAC_layout{}), cta_coord));
|
| 341 |
+
cute::Tensor cEgE = local_tile(gE, shape(cEsE_layout{}), cta_coord);
|
| 342 |
+
|
| 343 |
+
cute::Tensor cAsA = cute::recast<ElementAMmaRawUnit>(make_tensor(make_smem_ptr(cute::recast_ptr<ElementAUint>(shared_storage.cAsA)), cAsA_layout{}));
|
| 344 |
+
cute::Tensor cACsAC = cute::recast<ElementAMmaRawUnit>(make_tensor(make_smem_ptr(cute::recast_ptr<ElementAUint>(shared_storage.cACsAC)), cACsAC_layout{}));
|
| 345 |
+
cute::Tensor cEsE = make_tensor(make_smem_ptr(cute::recast_ptr<ElementEMmaRaw>(shared_storage.cEsE)), cEsE_layout{});
|
| 346 |
+
cute::Tensor cEsE_chunk = cute::recast<ElementEChunk>(cEsE);
|
| 347 |
+
|
| 348 |
+
// * Handle in unit of Chunk when compress
|
| 349 |
+
using OneChunkSizeA = Int<LogicalElemsAMmaRawPerChunk>;
|
| 350 |
+
using OneChunkSizeAC = Int<PhysicalElemsAMmaRawPerChunk>;
|
| 351 |
+
using OneChunkSizeE = Int<LogicalElemsAPerChunk / ElementESparsityPerChunk{}>;
|
| 352 |
+
using NumOneChunkK = Int<cutlass::ceil_div(TensorEAtomK{}, LogicalElemsAPerChunk)>;
|
| 353 |
+
|
| 354 |
+
cute::Tensor cAsA_log_chunk = logical_divide(cAsA, make_shape(_, OneChunkSizeA{}));
|
| 355 |
+
cute::Tensor cACsAC_log_chunk = logical_divide(cACsAC, make_shape(_, OneChunkSizeAC{}));
|
| 356 |
+
cute::Tensor cEsE_log_chunk = logical_divide(cEsE_chunk, make_shape(_, OneChunkSizeE{}));
|
| 357 |
+
|
| 358 |
+
// * Corner Case Handle
|
| 359 |
+
const auto GemmM_within_Cta = (GemmM - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmM - blockIdx_X * TensorEAtomM{};
|
| 360 |
+
const auto GemmK_within_Cta = ( (GemmK - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmK - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw;
|
| 361 |
+
const auto GemmK_NumOneChunk_within_Cta = GemmK_within_Cta / LogicalElemsAMmaRawPerChunk;
|
| 362 |
+
|
| 363 |
+
const auto GemmMAlignedAC = cutlass::round_up(GemmM, TensorAAlignmentM);
|
| 364 |
+
const auto GemmKAlignedAC = cutlass::round_up(GemmK, TensorAAlignmentK);
|
| 365 |
+
const auto GemmMAlignedAC_within_Cta = (GemmMAlignedAC - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmMAlignedAC - blockIdx_X * TensorEAtomM{};
|
| 366 |
+
const auto GemmKAlignedAC_within_Cta = ( (GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw;
|
| 367 |
+
|
| 368 |
+
// * Clear CTA Smem Tensor
|
| 369 |
+
cooperative_clear<MaxThreadsPerBlock>(threadIdx_X, cACsAC);
|
| 370 |
+
cooperative_clear<MaxThreadsPerBlock>(threadIdx_X, cEsE);
|
| 371 |
+
|
| 372 |
+
// * Input CTA Tensor G to S
|
| 373 |
+
if (GemmM_within_Cta == TensorEAtomM{} && GemmK_within_Cta == TensorEAtomK{}) {
|
| 374 |
+
copy_vec_pred<false, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
|
| 375 |
+
}
|
| 376 |
+
else {
|
| 377 |
+
copy_vec_pred<true, LayoutATag>(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
// Construct a sign bit mask for handling negative zeros
|
| 381 |
+
ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 };
|
| 382 |
+
if constexpr (has_negative_zero_v<ElementA>) {
|
| 383 |
+
ElementAMmaRawUnit one_sign_mask = static_cast<ElementAMmaRawUnit>(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v<ElementA> - 1)));
|
| 384 |
+
for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) {
|
| 385 |
+
sign_mask = static_cast<ElementAMmaRawUnit>((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v<ElementA>));
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
// * Compress
|
| 390 |
+
// cACsAC is always row major order
|
| 391 |
+
// TensorEAtomM threads perform the compression, each thread compress one row
|
| 392 |
+
const int row_i = threadIdx_X;
|
| 393 |
+
if (row_i < GemmM_within_Cta) {
|
| 394 |
+
|
| 395 |
+
CUTE_UNROLL
|
| 396 |
+
for (int col_chunk_i = 0; col_chunk_i < NumOneChunkK{}; ++col_chunk_i) {
|
| 397 |
+
if (col_chunk_i < GemmK_NumOneChunk_within_Cta) {
|
| 398 |
+
// Compress is handled in unit of ElementAMmaRawUnit
|
| 399 |
+
cute::Tensor tAsA = cAsA_log_chunk(row_i, make_coord(_, col_chunk_i));
|
| 400 |
+
cute::Tensor tACsAC = cACsAC_log_chunk(row_i, make_coord(_, col_chunk_i));
|
| 401 |
+
cute::Tensor tEsE = cEsE_log_chunk(row_i, make_coord(_, col_chunk_i));
|
| 402 |
+
|
| 403 |
+
int non_zero_cnt = 0;
|
| 404 |
+
// None zero element indx
|
| 405 |
+
// e.g.
|
| 406 |
+
// 2:4 sparsity [x 0 0 x]
|
| 407 |
+
// non_zero_elt_log_idx = [0, 3]
|
| 408 |
+
int non_zero_elt_log_idx[OneChunkSizeAC{}] = { 0 };
|
| 409 |
+
|
| 410 |
+
// * Find None Zero Element Idx within Chunk
|
| 411 |
+
CUTE_UNROLL
|
| 412 |
+
for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) {
|
| 413 |
+
ElementAMmaRawUnit elem_A = tAsA[elt_log_idx];
|
| 414 |
+
|
| 415 |
+
// Handle negative 0
|
| 416 |
+
ElementAMmaRawUnit masked_elem_A = elem_A;
|
| 417 |
+
if constexpr (has_negative_zero_v<ElementA>) {
|
| 418 |
+
masked_elem_A = elem_A & sign_mask;
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
if (masked_elem_A != ElementAMmaRawUnit{0}) {
|
| 422 |
+
non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx;
|
| 423 |
+
tACsAC[non_zero_cnt] = elem_A;
|
| 424 |
+
non_zero_cnt++;
|
| 425 |
+
}
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
// * Corner Case for 2:4 sparsity
|
| 429 |
+
if constexpr (cute::sizeof_bits_v<ElementAMmaRawUnit> < 32) {
|
| 430 |
+
// i.e. [0 0 0 x] -> [(0) 0 0 x]
|
| 431 |
+
if (non_zero_cnt == 1 && non_zero_elt_log_idx[0] == 3) {
|
| 432 |
+
tACsAC[1] = tACsAC[0];
|
| 433 |
+
tACsAC[0] = ElementAMmaRawUnit{0};
|
| 434 |
+
non_zero_elt_log_idx[0] = 0;
|
| 435 |
+
non_zero_elt_log_idx[1] = 3;
|
| 436 |
+
}
|
| 437 |
+
// i.e. [0 0 x 0] -> [0 0 x (0)]
|
| 438 |
+
// i.e. [0 x 0 0] -> [0 x 0 (0)]
|
| 439 |
+
// i.e. [x 0 0 0] -> [x 0 0 (0)]
|
| 440 |
+
else if (non_zero_cnt == 1) {
|
| 441 |
+
tACsAC[1] = ElementAMmaRawUnit{0};
|
| 442 |
+
non_zero_elt_log_idx[1] = 3;
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
// * Set Metadata Bits
|
| 447 |
+
MetadataOneChunk metadata_one_chunk;
|
| 448 |
+
CUTE_UNROLL
|
| 449 |
+
for (int elt_phy_idx = 0; elt_phy_idx < OneChunkSizeAC{}; elt_phy_idx++) {
|
| 450 |
+
metadata_one_chunk.set_metadata_bits(non_zero_elt_log_idx[elt_phy_idx], elt_phy_idx);
|
| 451 |
+
}
|
| 452 |
+
tEsE[0] = metadata_one_chunk.storage();
|
| 453 |
+
|
| 454 |
+
}
|
| 455 |
+
else {
|
| 456 |
+
break;
|
| 457 |
+
}
|
| 458 |
+
}
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
// * Sync after Compress
|
| 462 |
+
__syncthreads();
|
| 463 |
+
|
| 464 |
+
// * Output Cta Tensor S to G
|
| 465 |
+
if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) {
|
| 466 |
+
constexpr int MaxVecBits = 128; // STG.128
|
| 467 |
+
cute::cooperative_copy<MaxThreadsPerBlock, MaxVecBits>(threadIdx_X, cEsE, cEgE);
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) {
|
| 471 |
+
copy_vec_pred<false, LayoutATag>(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value));
|
| 472 |
+
}
|
| 473 |
+
else {
|
| 474 |
+
copy_vec_pred<true, LayoutATag>(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value));
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
} // end of structure_sparse_compress()
|
| 478 |
+
|
| 479 |
+
template<uint32_t NumThreads,
|
| 480 |
+
typename TensorSrc>
|
| 481 |
+
CUTE_DEVICE
|
| 482 |
+
static void
|
| 483 |
+
cooperative_clear(
|
| 484 |
+
uint32_t const& tid,
|
| 485 |
+
TensorSrc dSrc) {
|
| 486 |
+
|
| 487 |
+
auto dSrctSrc = local_partition(dSrc, make_layout(make_shape(NumThreads, _1{})), tid);
|
| 488 |
+
cute::clear(dSrctSrc);
|
| 489 |
+
|
| 490 |
+
// Sync all thread data access
|
| 491 |
+
__syncthreads();
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
template <bool pred,
|
| 495 |
+
typename LayoutTag,
|
| 496 |
+
typename TensorSrc,
|
| 497 |
+
typename TensorDst>
|
| 498 |
+
CUTE_DEVICE
|
| 499 |
+
static void
|
| 500 |
+
copy_vec_pred(
|
| 501 |
+
TensorSrc dSrc,
|
| 502 |
+
TensorDst dDst,
|
| 503 |
+
int threadIdx_X,
|
| 504 |
+
int valid_rows,
|
| 505 |
+
int valid_cols) {
|
| 506 |
+
|
| 507 |
+
constexpr bool IsRowMajor = cute::is_same_v<LayoutTag, cutlass::layout::RowMajor>;
|
| 508 |
+
using Element = typename TensorSrc::element_type;
|
| 509 |
+
constexpr bool IsQmmaF6 = cute::sizeof_bits_v<Element> == 6;
|
| 510 |
+
|
| 511 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dSrc))>, "shape(dSrc) needs to be static");
|
| 512 |
+
CUTE_STATIC_ASSERT(cute::is_static_v<decltype(shape(dDst))>, "shape(dDst) needs to be static");
|
| 513 |
+
CUTE_STATIC_ASSERT(cute::sizeof_bits_v<typename TensorSrc::element_type> == cute::sizeof_bits_v<typename TensorDst::element_type>,
|
| 514 |
+
"dSrc and dDst need to have same element bit width");
|
| 515 |
+
CUTE_STATIC_ASSERT(cute::size(dSrc) == cute::size(dDst), "dSrc and dDst need to have same size");
|
| 516 |
+
|
| 517 |
+
// ValueShape
|
| 518 |
+
using ValueShape =
|
| 519 |
+
cute::conditional_t<IsQmmaF6,
|
| 520 |
+
Shape<Int<1>, Int<1>>,
|
| 521 |
+
cute::conditional_t<IsRowMajor,
|
| 522 |
+
Shape<Int<1>, Int<128 / sizeof_bits_v<Element>>>,
|
| 523 |
+
Shape<Int<128 / sizeof_bits_v<Element>>, Int<1>>>
|
| 524 |
+
>;
|
| 525 |
+
|
| 526 |
+
constexpr int ValueShapeRows = shape<0>(ValueShape{});
|
| 527 |
+
constexpr int ValueShapeCols = shape<1>(ValueShape{});
|
| 528 |
+
|
| 529 |
+
// ThreadShape
|
| 530 |
+
using ThreadShape =
|
| 531 |
+
cute::conditional_t<IsQmmaF6,
|
| 532 |
+
cute::conditional_t<IsRowMajor,
|
| 533 |
+
Shape<Int<MaxThreadsPerBlock>, Int<1>>,
|
| 534 |
+
Shape<Int<1>, Int<MaxThreadsPerBlock>>>,
|
| 535 |
+
cute::conditional_t<IsRowMajor,
|
| 536 |
+
Shape<Int<MaxThreadsPerBlock / (shape<1>(dSrc) / ValueShapeCols)>, Int< (shape<1>(dSrc) / ValueShapeCols)>>,
|
| 537 |
+
Shape<Int< (shape<0>(dSrc) / ValueShapeRows)>, Int<MaxThreadsPerBlock / (shape<0>(dSrc) / ValueShapeRows)>>>
|
| 538 |
+
>;
|
| 539 |
+
|
| 540 |
+
constexpr int ThreadShapeRows = shape<0>(ThreadShape{});
|
| 541 |
+
constexpr int ThreadShapeCols = shape<1>(ThreadShape{});
|
| 542 |
+
|
| 543 |
+
const int threadIdx_X_row = threadIdx_X / ThreadShapeCols;
|
| 544 |
+
const int threadIdx_X_col = threadIdx_X % ThreadShapeCols;
|
| 545 |
+
|
| 546 |
+
// Row Major
|
| 547 |
+
if constexpr (IsRowMajor) {
|
| 548 |
+
CUTE_UNROLL
|
| 549 |
+
for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) {
|
| 550 |
+
CUTE_UNROLL
|
| 551 |
+
for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) {
|
| 552 |
+
CUTE_UNROLL
|
| 553 |
+
for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) {
|
| 554 |
+
CUTE_UNROLL
|
| 555 |
+
for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
|
| 556 |
+
const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
|
| 557 |
+
const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
|
| 558 |
+
if constexpr ( (not pred) and (not IsQmmaF6) ) {
|
| 559 |
+
dDst(row_i, col_i) = dSrc(row_i, col_i);
|
| 560 |
+
}
|
| 561 |
+
else {
|
| 562 |
+
if (row_i < valid_rows && col_i < valid_cols) {
|
| 563 |
+
dDst(row_i, col_i) = dSrc(row_i, col_i);
|
| 564 |
+
}
|
| 565 |
+
}
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
}
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
// Col Major
|
| 572 |
+
else {
|
| 573 |
+
CUTE_UNROLL
|
| 574 |
+
for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) {
|
| 575 |
+
CUTE_UNROLL
|
| 576 |
+
for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) {
|
| 577 |
+
CUTE_UNROLL
|
| 578 |
+
for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) {
|
| 579 |
+
CUTE_UNROLL
|
| 580 |
+
for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) {
|
| 581 |
+
const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr;
|
| 582 |
+
const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr;
|
| 583 |
+
if constexpr ( (not pred) and (not IsQmmaF6) ) {
|
| 584 |
+
dDst(row_i, col_i) = dSrc(row_i, col_i);
|
| 585 |
+
}
|
| 586 |
+
else {
|
| 587 |
+
if (row_i < valid_rows && col_i < valid_cols) {
|
| 588 |
+
dDst(row_i, col_i) = dSrc(row_i, col_i);
|
| 589 |
+
}
|
| 590 |
+
}
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
}
|
| 594 |
+
}
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
// Sync all thread data access
|
| 598 |
+
__syncthreads();
|
| 599 |
+
} // end of copy_vec_pred()
|
| 600 |
+
|
| 601 |
+
};
|
| 602 |
+
|
| 603 |
+
} // namespace cutlass::transform::kernel
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Compress utils for structured sparse kernels
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <algorithm> // std::fill
|
| 39 |
+
#include <array> // std::array
|
| 40 |
+
#include <random> // std::mt19937
|
| 41 |
+
|
| 42 |
+
#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v
|
| 43 |
+
#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor
|
| 44 |
+
#include "cutlass/arch/arch.h" // cutlass::arch::SmXY
|
| 45 |
+
#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false
|
| 46 |
+
#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t
|
| 47 |
+
#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up
|
| 48 |
+
#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes
|
| 49 |
+
|
| 50 |
+
#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp"
|
| 51 |
+
|
| 52 |
+
namespace cutlass::transform::kernel {
|
| 53 |
+
|
| 54 |
+
template<
|
| 55 |
+
class ProblemShape_,
|
| 56 |
+
class ElementA_,
|
| 57 |
+
class LayoutATag_,
|
| 58 |
+
class SparseConfig_
|
| 59 |
+
>
|
| 60 |
+
class StructuredSparseCompressorUtility {
|
| 61 |
+
public:
|
| 62 |
+
using SparseConfig = SparseConfig_;
|
| 63 |
+
using ProblemShape = ProblemShape_;
|
| 64 |
+
|
| 65 |
+
//* EltA
|
| 66 |
+
using ElementA = ElementA_;
|
| 67 |
+
using LayoutATag = LayoutATag_;
|
| 68 |
+
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutATag>;
|
| 69 |
+
using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw;
|
| 70 |
+
using ElementASparsity = typename SparseConfig::ElementASparsity;
|
| 71 |
+
using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity;
|
| 72 |
+
|
| 73 |
+
//* EltE
|
| 74 |
+
using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw;
|
| 75 |
+
using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity;
|
| 76 |
+
|
| 77 |
+
//* AtomE
|
| 78 |
+
using TensorEAtom = typename SparseConfig::TensorEAtom;
|
| 79 |
+
using TensorEAtomK = typename SparseConfig::TensorEAtomK;
|
| 80 |
+
using TensorEAtomM = typename SparseConfig::TensorEAtomM;
|
| 81 |
+
|
| 82 |
+
static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{};
|
| 83 |
+
static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{};
|
| 84 |
+
static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{};
|
| 85 |
+
static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
| 86 |
+
static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw);
|
| 87 |
+
|
| 88 |
+
//* Alignment
|
| 89 |
+
static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{};
|
| 90 |
+
static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{};
|
| 91 |
+
static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{};
|
| 92 |
+
static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{};
|
| 93 |
+
|
| 94 |
+
StructuredSparseCompressorUtility() = default;
|
| 95 |
+
|
| 96 |
+
StructuredSparseCompressorUtility(ProblemShape problem, StrideA dA) {
|
| 97 |
+
set_problem_size(problem, dA);
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
void set_problem_size(ProblemShape problem, StrideA dA_) {
|
| 101 |
+
M = cute::size<0>(problem);
|
| 102 |
+
K = cute::size<2>(problem);
|
| 103 |
+
L = cute::size<3>(problem);
|
| 104 |
+
|
| 105 |
+
// The following three vars are logical elem count!
|
| 106 |
+
K_alignedA = round_up(K, TensorAAlignmentK);
|
| 107 |
+
M_alignedA = round_up(M, TensorAAlignmentM);
|
| 108 |
+
K_alignedE = round_up(K, TensorEAlignmentK);
|
| 109 |
+
M_alignedE = round_up(M, TensorEAlignmentM);
|
| 110 |
+
|
| 111 |
+
dA = dA_;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
/**
|
| 115 |
+
* @brief Get the TensorE number of ElementE along K after alignment requirement
|
| 116 |
+
*
|
| 117 |
+
* @return int : number of ElementE (uint8_t) along K-dim
|
| 118 |
+
*/
|
| 119 |
+
int get_metadata_m_physical() const {
|
| 120 |
+
return M_alignedE;
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
/**
|
| 124 |
+
* @brief Get the TensorE number of ElementE along M after alignment requirement
|
| 125 |
+
*
|
| 126 |
+
* @return int : number of ElementE (uint8_t) along M-dim
|
| 127 |
+
*/
|
| 128 |
+
int get_metadata_k_physical() const {
|
| 129 |
+
return K_alignedE / ElementEMmaSparsity{};
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
/**
|
| 133 |
+
* @brief Get the TensorACompressed number of ElementA along K after alignment requirement
|
| 134 |
+
*
|
| 135 |
+
* @return int : number of ElementA along K-dim
|
| 136 |
+
*/
|
| 137 |
+
int get_tensorA_k_physical() const {
|
| 138 |
+
return K_alignedA / ElementASparsity{};
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/**
|
| 142 |
+
* @brief Get the TensorACompressed number of ElementA along M after alignment requirement
|
| 143 |
+
*
|
| 144 |
+
* @return int : number of ElementA along M-dim
|
| 145 |
+
*/
|
| 146 |
+
int get_tensorA_m_physical() const {
|
| 147 |
+
return M_alignedA;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
/**
|
| 151 |
+
* @brief Get the TensorACompressed Bytes
|
| 152 |
+
*
|
| 153 |
+
* @return uint64_t bytes
|
| 154 |
+
*/
|
| 155 |
+
uint64_t get_compressed_tensor_A_bytes() const {
|
| 156 |
+
const auto tensor_a_comp_num_elt_a = get_tensorA_m_physical() * get_tensorA_k_physical() * L;
|
| 157 |
+
const auto tensor_a_comp_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_a_comp_num_elt_a * cute::sizeof_bits_v<ElementA>);
|
| 158 |
+
return tensor_a_comp_bytes;
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
/**
|
| 162 |
+
* @brief Get the TensorA Bytes
|
| 163 |
+
*
|
| 164 |
+
* @return uint64_t bytes
|
| 165 |
+
*/
|
| 166 |
+
uint64_t get_raw_tensor_A_bytes() const {
|
| 167 |
+
const auto tensor_a_num_elt_a = uint64_t(M) * uint64_t(K) * uint64_t(L);
|
| 168 |
+
const auto tensor_a_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_a_num_elt_a * cute::sizeof_bits_v<ElementA>);
|
| 169 |
+
return tensor_a_bytes;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
/**
|
| 173 |
+
* @brief Get the TensorE Bytes
|
| 174 |
+
*
|
| 175 |
+
* @return uint64_t bytes
|
| 176 |
+
*/
|
| 177 |
+
uint64_t get_tensor_E_bytes() const {
|
| 178 |
+
const auto tensor_e_num_elt_a = uint64_t(get_metadata_m_physical()) * uint64_t(get_metadata_k_physical()) * uint64_t(L);
|
| 179 |
+
const auto tensor_e_bytes = cutlass::bits_to_bytes<uint64_t>(tensor_e_num_elt_a * cute::sizeof_bits_v<ElementEMmaRaw>);
|
| 180 |
+
return tensor_e_bytes;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
constexpr auto fill_layoutA_from_compressor() const {
|
| 184 |
+
return SparseConfig::fill_layoutA(cute::make_tuple(M,_1{},K,L));
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
constexpr auto fill_layoutE_from_compressor() const {
|
| 188 |
+
return SparseConfig::fill_layoutE(cute::make_tuple(M,_1{},K,L));
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
void structure_sparse_zero_mask_fill(void* host_a_ptr, uint64_t seed) {
|
| 192 |
+
|
| 193 |
+
constexpr int ChunkSize = LogicalElemsAMmaRawPerChunk;
|
| 194 |
+
using ChunkElement = cute::uint_bit_t<cute::sizeof_bits_v<ElementAMmaRaw>>;
|
| 195 |
+
|
| 196 |
+
cute::Tensor gA_eltA = cute::make_tensor(
|
| 197 |
+
cute::recast_ptr<ElementA>(host_a_ptr),
|
| 198 |
+
cute::make_layout(make_shape(M, K, L), dA));
|
| 199 |
+
|
| 200 |
+
// Input TensorA is handled in unit of ElementAMmaRaw instead of ElementA
|
| 201 |
+
cute::Tensor gA = cute::recast<ChunkElement>(gA_eltA);
|
| 202 |
+
|
| 203 |
+
// Extract out the Chunk from K-mode
|
| 204 |
+
Tensor gA_chunk = cute::zipped_divide(gA, cute::Shape<_1,cute::Int<ChunkSize>>{}); // (Chunk, Rest)
|
| 205 |
+
|
| 206 |
+
// Half of the data is zero to indicate sparsityA = 2
|
| 207 |
+
std::array<int, ChunkSize> nnzb_indicator{};
|
| 208 |
+
for (size_t i = 1; i < nnzb_indicator.size(); i += 2) {
|
| 209 |
+
nnzb_indicator.at(i) = 1;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
std::mt19937 rng(seed);
|
| 213 |
+
auto rest_shape = cute::shape<1>(gA_chunk);
|
| 214 |
+
for (auto iter = cute::make_coord_iterator(rest_shape); iter != cute::ForwardCoordIteratorSentinel{}; ++iter) {
|
| 215 |
+
std::shuffle(nnzb_indicator.begin(), nnzb_indicator.end(), rng);
|
| 216 |
+
for (int c = 0; c < size<0>(gA_chunk); ++c) { // for each elem within chunk
|
| 217 |
+
if (nnzb_indicator[c] == 0) {
|
| 218 |
+
gA_chunk(c, *iter) = ChunkElement{0};
|
| 219 |
+
}
|
| 220 |
+
} // end of within chunk
|
| 221 |
+
} // end of chunk_idx
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
int M{-1};
|
| 225 |
+
int K{-1};
|
| 226 |
+
int L{-1};
|
| 227 |
+
StrideA dA{};
|
| 228 |
+
|
| 229 |
+
private:
|
| 230 |
+
int K_alignedA{-1};
|
| 231 |
+
int M_alignedA{-1};
|
| 232 |
+
int K_alignedE{-1};
|
| 233 |
+
int M_alignedE{-1};
|
| 234 |
+
};
|
| 235 |
+
|
| 236 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 237 |
+
|
| 238 |
+
template<
|
| 239 |
+
class ProblemShape,
|
| 240 |
+
class ElementA,
|
| 241 |
+
class LayoutATag,
|
| 242 |
+
class SparseConfig,
|
| 243 |
+
class ArchTag
|
| 244 |
+
>
|
| 245 |
+
struct StructuredSparseCompressorSelector {
|
| 246 |
+
static_assert(cutlass::detail::dependent_false<ArchTag>,
|
| 247 |
+
"Could not select a structured sparse compressor for given parameters.");
|
| 248 |
+
};
|
| 249 |
+
|
| 250 |
+
template<
|
| 251 |
+
class ProblemShape,
|
| 252 |
+
class ElementA,
|
| 253 |
+
class LayoutATag,
|
| 254 |
+
class SparseConfig
|
| 255 |
+
>
|
| 256 |
+
struct StructuredSparseCompressorSelector<
|
| 257 |
+
ProblemShape,
|
| 258 |
+
ElementA,
|
| 259 |
+
LayoutATag,
|
| 260 |
+
SparseConfig,
|
| 261 |
+
arch::Sm90> {
|
| 262 |
+
using Compressor = SM90StructuredSparseCompressor<
|
| 263 |
+
ProblemShape,
|
| 264 |
+
ElementA,
|
| 265 |
+
LayoutATag,
|
| 266 |
+
SparseConfig
|
| 267 |
+
>;
|
| 268 |
+
};
|
| 269 |
+
|
| 270 |
+
template<
|
| 271 |
+
class ProblemShape,
|
| 272 |
+
class ElementA,
|
| 273 |
+
class LayoutATag,
|
| 274 |
+
class SparseConfig
|
| 275 |
+
>
|
| 276 |
+
struct StructuredSparseCompressorSelector<
|
| 277 |
+
ProblemShape,
|
| 278 |
+
ElementA,
|
| 279 |
+
LayoutATag,
|
| 280 |
+
SparseConfig,
|
| 281 |
+
arch::Sm100> {
|
| 282 |
+
using Compressor = SM90StructuredSparseCompressor<
|
| 283 |
+
ProblemShape,
|
| 284 |
+
ElementA,
|
| 285 |
+
LayoutATag,
|
| 286 |
+
SparseConfig
|
| 287 |
+
>;
|
| 288 |
+
};
|
| 289 |
+
|
| 290 |
+
template<
|
| 291 |
+
class ProblemShape,
|
| 292 |
+
class ElementA,
|
| 293 |
+
class LayoutATag,
|
| 294 |
+
class SparseConfig
|
| 295 |
+
>
|
| 296 |
+
struct StructuredSparseCompressorSelector<
|
| 297 |
+
ProblemShape,
|
| 298 |
+
ElementA,
|
| 299 |
+
LayoutATag,
|
| 300 |
+
SparseConfig,
|
| 301 |
+
arch::Sm120> {
|
| 302 |
+
using Compressor = SM90StructuredSparseCompressor<
|
| 303 |
+
ProblemShape,
|
| 304 |
+
ElementA,
|
| 305 |
+
LayoutATag,
|
| 306 |
+
SparseConfig
|
| 307 |
+
>;
|
| 308 |
+
};
|
| 309 |
+
|
| 310 |
+
template<
|
| 311 |
+
class ProblemShape,
|
| 312 |
+
class ElementA,
|
| 313 |
+
class LayoutATag,
|
| 314 |
+
class SparseConfig,
|
| 315 |
+
class ArchTag
|
| 316 |
+
>
|
| 317 |
+
using StructuredSparseCompressor = typename StructuredSparseCompressorSelector<
|
| 318 |
+
ProblemShape,
|
| 319 |
+
ElementA,
|
| 320 |
+
LayoutATag,
|
| 321 |
+
SparseConfig,
|
| 322 |
+
ArchTag
|
| 323 |
+
>::Compressor;
|
| 324 |
+
|
| 325 |
+
} // End namespace cutlass::transform::kernel
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h
ADDED
|
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing how threads are mapped to a given tile.
|
| 33 |
+
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/coord.h"
|
| 41 |
+
#include "cutlass/predicate_vector.h"
|
| 42 |
+
#include "cutlass/tensor_ref.h"
|
| 43 |
+
#include "cutlass/tensor_view.h"
|
| 44 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace transform {
|
| 50 |
+
|
| 51 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 52 |
+
|
| 53 |
+
/// Strip-mines a pitch-linear tile among a given number of threads, first along
|
| 54 |
+
/// the contiguous dimension then along the strided dimension.
|
| 55 |
+
///
|
| 56 |
+
/// The tile must be divisible by the thread count such that all threads may
|
| 57 |
+
/// execute the same number of iterations with the same delta to exhaustively
|
| 58 |
+
/// cover the tile.
|
| 59 |
+
///
|
| 60 |
+
/// This class satisfies the "RegularThreadMapping" concept.
|
| 61 |
+
///
|
| 62 |
+
/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor
|
| 63 |
+
/// kernels.
|
| 64 |
+
template <
|
| 65 |
+
typename Shape_,
|
| 66 |
+
int Threads,
|
| 67 |
+
int ElementsPerAccess = 1
|
| 68 |
+
>
|
| 69 |
+
struct PitchLinearStripminedThreadMap {
|
| 70 |
+
|
| 71 |
+
/// Tensor coordinate
|
| 72 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 73 |
+
|
| 74 |
+
/// Tile shape
|
| 75 |
+
using Shape = Shape_;
|
| 76 |
+
|
| 77 |
+
/// Number of threads total
|
| 78 |
+
static int const kThreads = Threads;
|
| 79 |
+
|
| 80 |
+
/// Extract vector length from Layout
|
| 81 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 82 |
+
|
| 83 |
+
/// Shape of access by each thread
|
| 84 |
+
using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
|
| 85 |
+
|
| 86 |
+
/// Internal implementation details
|
| 87 |
+
struct Detail {
|
| 88 |
+
|
| 89 |
+
static_assert(!(Shape::kContiguous % kElementsPerAccess), "");
|
| 90 |
+
|
| 91 |
+
/// Shape of the tile in units of vectors
|
| 92 |
+
using ShapeVec = layout::PitchLinearShape<
|
| 93 |
+
Shape::kContiguous / kElementsPerAccess,
|
| 94 |
+
Shape::kStrided
|
| 95 |
+
>;
|
| 96 |
+
|
| 97 |
+
static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) ||
|
| 98 |
+
(!(kThreads % ShapeVec::kContiguous)),
|
| 99 |
+
"Shape must be divisible by number of iterations of each thread.");
|
| 100 |
+
};
|
| 101 |
+
|
| 102 |
+
/// Number of iterations by each thread
|
| 103 |
+
using Iterations = typename platform::conditional<
|
| 104 |
+
Threads >= Detail::ShapeVec::kContiguous,
|
| 105 |
+
layout::PitchLinearShape<
|
| 106 |
+
1,
|
| 107 |
+
// Redo the comparison here to work around divide by zero compiler
|
| 108 |
+
// error. The compiler evaluates both path of platform::conditional.
|
| 109 |
+
(Threads >= Detail::ShapeVec::kContiguous
|
| 110 |
+
? (Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) /
|
| 111 |
+
(kThreads / Detail::ShapeVec::kContiguous)
|
| 112 |
+
: 0)>,
|
| 113 |
+
layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
|
| 114 |
+
Detail::ShapeVec::kStrided>>::type;
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
/// Interval between accesses along each dimension of the tensor's logical coordinate space
|
| 118 |
+
/// (in units of Elements)
|
| 119 |
+
using Delta = typename platform::conditional<
|
| 120 |
+
Threads >= Detail::ShapeVec::kContiguous,
|
| 121 |
+
layout::PitchLinearShape<
|
| 122 |
+
1,
|
| 123 |
+
kThreads / Detail::ShapeVec::kContiguous
|
| 124 |
+
>,
|
| 125 |
+
layout::PitchLinearShape<
|
| 126 |
+
kThreads * kElementsPerAccess,
|
| 127 |
+
1
|
| 128 |
+
>
|
| 129 |
+
>::type;
|
| 130 |
+
|
| 131 |
+
/// Shape of the tile in units of vectors
|
| 132 |
+
using StorageShape = typename platform::conditional<
|
| 133 |
+
Threads >= Detail::ShapeVec::kContiguous,
|
| 134 |
+
layout::PitchLinearShape<Shape::kContiguous,
|
| 135 |
+
Iterations::kStrided*(kThreads / Detail::ShapeVec::kContiguous)>,
|
| 136 |
+
layout::PitchLinearShape<Shape::kContiguous, Shape::kStrided>>::type;
|
| 137 |
+
|
| 138 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
| 139 |
+
/// (in units of Elements)
|
| 140 |
+
CUTLASS_HOST_DEVICE
|
| 141 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 142 |
+
return TensorCoord(
|
| 143 |
+
(thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess,
|
| 144 |
+
thread_id / Detail::ShapeVec::kContiguous);
|
| 145 |
+
}
|
| 146 |
+
};
|
| 147 |
+
|
| 148 |
+
/// This ThreadMap is used by GEMV
|
| 149 |
+
template <
|
| 150 |
+
typename Shape,
|
| 151 |
+
int Threads,
|
| 152 |
+
int ElementsPerAccess = 1
|
| 153 |
+
>
|
| 154 |
+
struct PitchLinearTilePolicyStripminedThreadContiguous
|
| 155 |
+
{
|
| 156 |
+
static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0,
|
| 157 |
+
"Contiguous shape must divide number of threads");
|
| 158 |
+
|
| 159 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 160 |
+
|
| 161 |
+
static int const kThreads = Threads;
|
| 162 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 163 |
+
|
| 164 |
+
using Iterations = layout::PitchLinearShape<
|
| 165 |
+
Shape::kContiguous / (kThreads * kElementsPerAccess),
|
| 166 |
+
Shape::kStrided>;
|
| 167 |
+
|
| 168 |
+
using Delta = layout::PitchLinearShape<1, 1>;
|
| 169 |
+
|
| 170 |
+
CUTLASS_HOST_DEVICE
|
| 171 |
+
static TensorCoord initial_offset(int thread_id)
|
| 172 |
+
{
|
| 173 |
+
return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0);
|
| 174 |
+
}
|
| 175 |
+
};
|
| 176 |
+
|
| 177 |
+
template <
|
| 178 |
+
typename Shape,
|
| 179 |
+
int Threads,
|
| 180 |
+
int ElementsPerAccess = 1
|
| 181 |
+
>
|
| 182 |
+
struct PitchLinearTilePolicyStripminedThreadStrided
|
| 183 |
+
{
|
| 184 |
+
static_assert((Shape::kStrided % Threads == 0),
|
| 185 |
+
"Strided shape must divide number of threads");
|
| 186 |
+
|
| 187 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 188 |
+
|
| 189 |
+
static int const kThreads = Threads;
|
| 190 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 191 |
+
|
| 192 |
+
using Iterations = layout::PitchLinearShape<
|
| 193 |
+
Shape::kContiguous / kElementsPerAccess,
|
| 194 |
+
Shape::kStrided / kThreads>;
|
| 195 |
+
|
| 196 |
+
using Delta = layout::PitchLinearShape<1, 1>;
|
| 197 |
+
|
| 198 |
+
using ShapeVec = Shape;
|
| 199 |
+
|
| 200 |
+
CUTLASS_HOST_DEVICE
|
| 201 |
+
static TensorCoord initial_offset(int thread_id)
|
| 202 |
+
{
|
| 203 |
+
|
| 204 |
+
return TensorCoord(0, thread_id * Iterations::kStrided);
|
| 205 |
+
}
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 210 |
+
|
| 211 |
+
/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous
|
| 212 |
+
/// elements.
|
| 213 |
+
///
|
| 214 |
+
/// This ThreadMap is used by tensor core kernels.
|
| 215 |
+
template <
|
| 216 |
+
typename Shape_,
|
| 217 |
+
int Threads,
|
| 218 |
+
typename WarpThreadArrangement_,
|
| 219 |
+
int ElementsPerAccess = 1
|
| 220 |
+
>
|
| 221 |
+
struct PitchLinearWarpRakedThreadMap {
|
| 222 |
+
|
| 223 |
+
/// Tensor coordinate
|
| 224 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 225 |
+
|
| 226 |
+
/// Tile shape
|
| 227 |
+
using Shape = Shape_;
|
| 228 |
+
|
| 229 |
+
/// Number of threads total
|
| 230 |
+
static int const kThreads = Threads;
|
| 231 |
+
|
| 232 |
+
/// Extract vector length from Layout
|
| 233 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 234 |
+
|
| 235 |
+
/// Shape of access by each thread
|
| 236 |
+
using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
|
| 237 |
+
|
| 238 |
+
/// Internal details made public to facilitate introspection
|
| 239 |
+
struct Detail {
|
| 240 |
+
|
| 241 |
+
/// Fixed arrangement of threads within a warp (units of threads).
|
| 242 |
+
using WarpThreadArrangement = WarpThreadArrangement_;
|
| 243 |
+
|
| 244 |
+
/// Number of threads per warp
|
| 245 |
+
static int const kWarpSize = WarpThreadArrangement::kCount;
|
| 246 |
+
|
| 247 |
+
/// Number of participating warps
|
| 248 |
+
static int const kWarpCount = kThreads / kWarpSize;
|
| 249 |
+
|
| 250 |
+
static_assert(
|
| 251 |
+
!(Shape::kContiguous % kElementsPerAccess),
|
| 252 |
+
"Shape must be divisible by vector length.");
|
| 253 |
+
|
| 254 |
+
/// Compute the 'shape' of the overall tile in units of vectors
|
| 255 |
+
using ShapeInAccesses = layout::PitchLinearShape<
|
| 256 |
+
Shape::kContiguous / kElementsPerAccess,
|
| 257 |
+
Shape::kStrided
|
| 258 |
+
>;
|
| 259 |
+
|
| 260 |
+
static_assert(
|
| 261 |
+
!(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous),
|
| 262 |
+
"ShapeInAccesses must be divisible by WarpThreadArrangement.");
|
| 263 |
+
|
| 264 |
+
static_assert(
|
| 265 |
+
!(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided),
|
| 266 |
+
"ShapeInAccesses must be divisible by WarpThreadArrangement.");
|
| 267 |
+
|
| 268 |
+
// compute number of warp-level accesses total
|
| 269 |
+
using WarpAccessIterations = layout::PitchLinearShape<
|
| 270 |
+
ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
|
| 271 |
+
ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
|
| 272 |
+
>;
|
| 273 |
+
|
| 274 |
+
// Divide it into the number of warps, first partitioning the strided dimension then the
|
| 275 |
+
// contiguous.
|
| 276 |
+
static int const kWarpsStrided =
|
| 277 |
+
(WarpAccessIterations::kStrided >= kWarpCount
|
| 278 |
+
? kWarpCount
|
| 279 |
+
: WarpAccessIterations::kStrided);
|
| 280 |
+
|
| 281 |
+
static int const kWarpsContiguous =
|
| 282 |
+
(kWarpCount > WarpAccessIterations::kStrided
|
| 283 |
+
? kWarpCount / kWarpsStrided
|
| 284 |
+
: 1);
|
| 285 |
+
|
| 286 |
+
/// Arrangement of warps within a threadblock-scoped tile
|
| 287 |
+
using WarpArrangement = layout::PitchLinearShape<
|
| 288 |
+
kWarpsContiguous, kWarpsStrided
|
| 289 |
+
>;
|
| 290 |
+
};
|
| 291 |
+
|
| 292 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 293 |
+
using Iterations = layout::PitchLinearShape<
|
| 294 |
+
Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
|
| 295 |
+
Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
|
| 296 |
+
>;
|
| 297 |
+
|
| 298 |
+
static_assert(Iterations::kCount,
|
| 299 |
+
"Number of iterations must be non-zero");
|
| 300 |
+
|
| 301 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 302 |
+
using Delta = layout::PitchLinearShape<
|
| 303 |
+
Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
|
| 304 |
+
Detail::WarpThreadArrangement::kStrided
|
| 305 |
+
>;
|
| 306 |
+
|
| 307 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
| 308 |
+
CUTLASS_HOST_DEVICE
|
| 309 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 310 |
+
|
| 311 |
+
int warp_id = (thread_id / Detail::kWarpSize);
|
| 312 |
+
int lane_id = (thread_id % Detail::kWarpSize);
|
| 313 |
+
|
| 314 |
+
//
|
| 315 |
+
// compute warp-level offset
|
| 316 |
+
//
|
| 317 |
+
|
| 318 |
+
// This is the shape of the entire area covered by a warp's memory access (in units of vectors)
|
| 319 |
+
layout::PitchLinearCoord warp_footprint{
|
| 320 |
+
Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
|
| 321 |
+
Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
|
| 322 |
+
};
|
| 323 |
+
|
| 324 |
+
// This is the offset of a specific warp (in units of vectors)
|
| 325 |
+
layout::PitchLinearCoord warp_offset{
|
| 326 |
+
(warp_id % Detail::kWarpsContiguous),
|
| 327 |
+
(warp_id / Detail::kWarpsContiguous)
|
| 328 |
+
};
|
| 329 |
+
|
| 330 |
+
// This is the offset of a specific thread within a warp (units of vectors)
|
| 331 |
+
layout::PitchLinearCoord thread_offset_in_warp{
|
| 332 |
+
lane_id % Detail::WarpThreadArrangement::kContiguous,
|
| 333 |
+
lane_id / Detail::WarpThreadArrangement::kContiguous
|
| 334 |
+
};
|
| 335 |
+
|
| 336 |
+
// This is the offset of a thread within a threadblock tile (units of vectors)
|
| 337 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
|
| 338 |
+
warp_footprint * warp_offset + thread_offset_in_warp;
|
| 339 |
+
|
| 340 |
+
// This is the offset of a thread within a threadblock tile (units of elements)
|
| 341 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
|
| 342 |
+
thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
|
| 343 |
+
thread_offset_in_threadblock_tile_vec.strided()
|
| 344 |
+
};
|
| 345 |
+
|
| 346 |
+
return thread_offset_in_threadblock_tile_base;
|
| 347 |
+
}
|
| 348 |
+
};
|
| 349 |
+
|
| 350 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 351 |
+
|
| 352 |
+
/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous
|
| 353 |
+
/// elements. Warps are arranged based on a stride.
|
| 354 |
+
///
|
| 355 |
+
/// This ThreadMap is used by tensor core kernels for NCxHWx layout.
|
| 356 |
+
template <
|
| 357 |
+
typename Shape_,
|
| 358 |
+
int Threads,
|
| 359 |
+
typename WarpThreadArrangement_,
|
| 360 |
+
int ElementsPerAccess = 1
|
| 361 |
+
>
|
| 362 |
+
struct PitchLinearStridedWarpRakedThreadMap {
|
| 363 |
+
|
| 364 |
+
/// Tensor coordinate
|
| 365 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 366 |
+
|
| 367 |
+
/// Tile shape
|
| 368 |
+
using Shape = Shape_;
|
| 369 |
+
|
| 370 |
+
/// Number of threads total
|
| 371 |
+
static int const kThreads = Threads;
|
| 372 |
+
|
| 373 |
+
using WarpThreadArrangement = WarpThreadArrangement_;
|
| 374 |
+
|
| 375 |
+
/// Extract vector length from Layout
|
| 376 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 377 |
+
|
| 378 |
+
/// Base ThreadMap
|
| 379 |
+
using BaseThreadMap = PitchLinearWarpRakedThreadMap<
|
| 380 |
+
Shape,
|
| 381 |
+
kThreads,
|
| 382 |
+
WarpThreadArrangement,
|
| 383 |
+
kElementsPerAccess
|
| 384 |
+
>;
|
| 385 |
+
|
| 386 |
+
/// Shape of access by each thread
|
| 387 |
+
using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape;
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
struct Detail {
|
| 391 |
+
|
| 392 |
+
using WarpThreadArrangement = WarpThreadArrangement_;
|
| 393 |
+
|
| 394 |
+
using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations;
|
| 395 |
+
|
| 396 |
+
static int const kWarpSize = BaseThreadMap::Detail::kWarpSize;
|
| 397 |
+
|
| 398 |
+
static int const kWarpCount = BaseThreadMap::Detail::kWarpCount;
|
| 399 |
+
|
| 400 |
+
using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses;
|
| 401 |
+
|
| 402 |
+
// Divide it into the number of warps, first partitioning the contiguous dimension then the
|
| 403 |
+
// stride.
|
| 404 |
+
static int const kWarpsContiguous =
|
| 405 |
+
(WarpAccessIterations::kContiguous >= kWarpCount
|
| 406 |
+
? kWarpCount
|
| 407 |
+
: WarpAccessIterations::kContiguous);
|
| 408 |
+
|
| 409 |
+
static int const kWarpsStrided =
|
| 410 |
+
(kWarpCount > WarpAccessIterations::kContiguous
|
| 411 |
+
? kWarpCount / kWarpsContiguous
|
| 412 |
+
: 1);
|
| 413 |
+
|
| 414 |
+
/// Arrangement of warps within a threadblock-scoped tile
|
| 415 |
+
using WarpArrangement = layout::PitchLinearShape<
|
| 416 |
+
kWarpsContiguous, kWarpsStrided
|
| 417 |
+
>;
|
| 418 |
+
|
| 419 |
+
};
|
| 420 |
+
|
| 421 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 422 |
+
using Iterations = layout::PitchLinearShape<
|
| 423 |
+
Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
|
| 424 |
+
Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
|
| 425 |
+
>;
|
| 426 |
+
|
| 427 |
+
static_assert(Iterations::kCount,
|
| 428 |
+
"Number of iterations must be non-zero");
|
| 429 |
+
|
| 430 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 431 |
+
using Delta = typename BaseThreadMap::Delta;
|
| 432 |
+
|
| 433 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
| 434 |
+
CUTLASS_HOST_DEVICE
|
| 435 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 436 |
+
|
| 437 |
+
int warp_id = (thread_id / Detail::kWarpSize);
|
| 438 |
+
int lane_id = (thread_id % Detail::kWarpSize);
|
| 439 |
+
|
| 440 |
+
//
|
| 441 |
+
// compute warp-level offset
|
| 442 |
+
//
|
| 443 |
+
|
| 444 |
+
// This is the shape of the entire area covered by a warp's memory access (in units of vectors)
|
| 445 |
+
layout::PitchLinearCoord warp_footprint{
|
| 446 |
+
Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
|
| 447 |
+
Detail::WarpThreadArrangement::kStrided * Iterations::kStrided
|
| 448 |
+
};
|
| 449 |
+
|
| 450 |
+
// This is the offset of a specific warp (in units of vectors)
|
| 451 |
+
layout::PitchLinearCoord warp_offset{
|
| 452 |
+
(warp_id % Detail::kWarpsContiguous),
|
| 453 |
+
(warp_id / Detail::kWarpsContiguous)
|
| 454 |
+
};
|
| 455 |
+
|
| 456 |
+
// This is the offset of a specific thread within a warp (units of vectors)
|
| 457 |
+
layout::PitchLinearCoord thread_offset_in_warp{
|
| 458 |
+
lane_id % Detail::WarpThreadArrangement::kContiguous,
|
| 459 |
+
lane_id / Detail::WarpThreadArrangement::kContiguous
|
| 460 |
+
};
|
| 461 |
+
|
| 462 |
+
// This is the offset of a thread within a threadblock tile (units of vectors)
|
| 463 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
|
| 464 |
+
warp_footprint * warp_offset + thread_offset_in_warp;
|
| 465 |
+
|
| 466 |
+
// This is the offset of a thread within a threadblock tile (units of elements)
|
| 467 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
|
| 468 |
+
thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
|
| 469 |
+
thread_offset_in_threadblock_tile_vec.strided()
|
| 470 |
+
};
|
| 471 |
+
|
| 472 |
+
return thread_offset_in_threadblock_tile_base;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
};
|
| 477 |
+
|
| 478 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 479 |
+
|
| 480 |
+
/// Transpose the existing ThreadMap. For example, interleaved layout is like
|
| 481 |
+
/// congruous in the global memory and crosswise in the shared memory. We need
|
| 482 |
+
/// to transpose the coordinates between two.
|
| 483 |
+
|
| 484 |
+
template <typename ThreadMap_, typename WarpThreadArrangement_>
|
| 485 |
+
struct TransposePitchLinearThreadMap {
|
| 486 |
+
/// Underlying ThreadMap
|
| 487 |
+
using ThreadMap = ThreadMap_;
|
| 488 |
+
|
| 489 |
+
/// Tensor coordinate
|
| 490 |
+
using TensorCoord = typename ThreadMap::TensorCoord;
|
| 491 |
+
|
| 492 |
+
/// Tile shape
|
| 493 |
+
using Shape = typename ThreadMap::Shape;
|
| 494 |
+
|
| 495 |
+
/// Number of threads total
|
| 496 |
+
static int const kThreads = ThreadMap::kThreads;
|
| 497 |
+
|
| 498 |
+
/// Extract vector length from Layout
|
| 499 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 500 |
+
|
| 501 |
+
/// Shape of access by each thread
|
| 502 |
+
using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
|
| 503 |
+
|
| 504 |
+
/// Internal details made public to facilitate introspection
|
| 505 |
+
struct Detail {
|
| 506 |
+
/// Fixed arrangement of threads within a warp (units of threads).
|
| 507 |
+
using WarpThreadArrangement = WarpThreadArrangement_;
|
| 508 |
+
|
| 509 |
+
/// Number of threads per warp
|
| 510 |
+
static int const kWarpSize = WarpThreadArrangement::kCount;
|
| 511 |
+
|
| 512 |
+
/// Number of participating warps
|
| 513 |
+
static int const kWarpCount = kThreads / kWarpSize;
|
| 514 |
+
|
| 515 |
+
static_assert(!(Shape::kContiguous % kElementsPerAccess),
|
| 516 |
+
"Shape must be divisible by vector length.");
|
| 517 |
+
|
| 518 |
+
/// Arrangement of warps within a threadblock-scoped tile
|
| 519 |
+
using WarpArrangement =
|
| 520 |
+
layout::PitchLinearShape<ThreadMap::Detail::kWarpsStrided,
|
| 521 |
+
ThreadMap::Detail::kWarpsContiguous>;
|
| 522 |
+
};
|
| 523 |
+
|
| 524 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 525 |
+
using Iterations =
|
| 526 |
+
layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
|
| 527 |
+
ThreadMap::Iterations::kContiguous>;
|
| 528 |
+
|
| 529 |
+
static_assert(Iterations::kContiguous == 1,
|
| 530 |
+
"Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose");
|
| 531 |
+
|
| 532 |
+
static_assert(Iterations::kCount, "Number of iterations must be non-zero");
|
| 533 |
+
|
| 534 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 535 |
+
using Delta =
|
| 536 |
+
layout::PitchLinearShape<Detail::WarpThreadArrangement::kContiguous *
|
| 537 |
+
kElementsPerAccess,
|
| 538 |
+
Detail::WarpThreadArrangement::kStrided>;
|
| 539 |
+
|
| 540 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical
|
| 541 |
+
/// coordinate space Note this is slightly different from the one of
|
| 542 |
+
/// PitchLinearWarpRakedThreadMap.
|
| 543 |
+
CUTLASS_HOST_DEVICE
|
| 544 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 545 |
+
|
| 546 |
+
int warp_id = (thread_id / Detail::kWarpSize);
|
| 547 |
+
int lane_id = (thread_id % Detail::kWarpSize);
|
| 548 |
+
|
| 549 |
+
//
|
| 550 |
+
// compute warp-level offset
|
| 551 |
+
//
|
| 552 |
+
|
| 553 |
+
// This is the shape of the entire area covered by a warp's memory access
|
| 554 |
+
// (in units of vectors)
|
| 555 |
+
layout::PitchLinearCoord warp_footprint{
|
| 556 |
+
Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
|
| 557 |
+
Detail::WarpThreadArrangement::kStrided * Iterations::kStrided};
|
| 558 |
+
|
| 559 |
+
// This is the offset of a specific warp (in units of vectors)
|
| 560 |
+
// Note the order of / and %. Also the 2nd operand is kStrided.
|
| 561 |
+
layout::PitchLinearCoord warp_offset{
|
| 562 |
+
(warp_id / Detail::WarpArrangement::kStrided),
|
| 563 |
+
(warp_id % Detail::WarpArrangement::kStrided)};
|
| 564 |
+
|
| 565 |
+
// This is the offset of a specific thread within a warp (units of vectors)
|
| 566 |
+
layout::PitchLinearCoord thread_offset_in_warp{
|
| 567 |
+
lane_id % Detail::WarpThreadArrangement::kContiguous,
|
| 568 |
+
lane_id / Detail::WarpThreadArrangement::kContiguous};
|
| 569 |
+
|
| 570 |
+
// This is the offset of a thread within a threadblock tile (units of
|
| 571 |
+
// vectors)
|
| 572 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
|
| 573 |
+
warp_footprint * warp_offset + thread_offset_in_warp;
|
| 574 |
+
|
| 575 |
+
// This is the offset of a thread within a threadblock tile (units of
|
| 576 |
+
// elements)
|
| 577 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
|
| 578 |
+
thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
|
| 579 |
+
thread_offset_in_threadblock_tile_vec.strided()};
|
| 580 |
+
|
| 581 |
+
return thread_offset_in_threadblock_tile_base;
|
| 582 |
+
}
|
| 583 |
+
};
|
| 584 |
+
|
| 585 |
+
template <typename ThreadMap_>
|
| 586 |
+
struct TransposePitchLinearThreadMapSimt {
|
| 587 |
+
/// Underlying ThreadMap
|
| 588 |
+
using ThreadMap = ThreadMap_;
|
| 589 |
+
|
| 590 |
+
/// Tensor coordinate
|
| 591 |
+
using TensorCoord = typename ThreadMap::TensorCoord;
|
| 592 |
+
|
| 593 |
+
/// Tile shape
|
| 594 |
+
using Shape = typename ThreadMap::Shape;
|
| 595 |
+
|
| 596 |
+
/// Number of threads total
|
| 597 |
+
static int const kThreads = ThreadMap::kThreads;
|
| 598 |
+
|
| 599 |
+
/// Extract vector length from Layout
|
| 600 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 601 |
+
|
| 602 |
+
static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1");
|
| 603 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 604 |
+
using Iterations =
|
| 605 |
+
layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
|
| 606 |
+
ThreadMap::Iterations::kContiguous>;
|
| 607 |
+
|
| 608 |
+
static_assert(Iterations::kCount, "Number of iterations must be non-zero");
|
| 609 |
+
|
| 610 |
+
static_assert(Iterations::kStrided == 1,
|
| 611 |
+
"Strided iteration has to be one to reuse the same shared store function with those that don't need transpose");
|
| 612 |
+
|
| 613 |
+
/// Shape of access by each thread
|
| 614 |
+
using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
|
| 615 |
+
|
| 616 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 617 |
+
using Delta =
|
| 618 |
+
layout::PitchLinearShape<ThreadMap::Delta::kStrided,
|
| 619 |
+
ThreadMap::Delta::kContiguous>;
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical
|
| 623 |
+
/// coordinate space Note this is slightly different from the one of
|
| 624 |
+
/// PitchLinearWarpRakedThreadMap.
|
| 625 |
+
CUTLASS_HOST_DEVICE
|
| 626 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 627 |
+
|
| 628 |
+
TensorCoord coord = ThreadMap::initial_offset(thread_id);
|
| 629 |
+
|
| 630 |
+
return TensorCoord(
|
| 631 |
+
coord.strided(),
|
| 632 |
+
coord.contiguous()
|
| 633 |
+
);
|
| 634 |
+
}
|
| 635 |
+
};
|
| 636 |
+
|
| 637 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory
|
| 641 |
+
/// accesses performed by each warp then distributes warps across them. Warps are striped in the
|
| 642 |
+
/// strided dimension and raked across the contiguous dimension.
|
| 643 |
+
template <
|
| 644 |
+
typename Shape_, /// Overall shape to partition in units of elements
|
| 645 |
+
int Threads, /// Number of partiticipation threads
|
| 646 |
+
typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp
|
| 647 |
+
int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size)
|
| 648 |
+
>
|
| 649 |
+
struct PitchLinearWarpStripedThreadMap {
|
| 650 |
+
|
| 651 |
+
/// Tensor coordinate
|
| 652 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 653 |
+
|
| 654 |
+
/// Tile shape
|
| 655 |
+
using Shape = Shape_;
|
| 656 |
+
|
| 657 |
+
/// Number of threads total
|
| 658 |
+
static int const kThreads = Threads;
|
| 659 |
+
|
| 660 |
+
/// Extract vector length from Layout
|
| 661 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 662 |
+
|
| 663 |
+
/// Shape of access by each thread
|
| 664 |
+
using ThreadAccessShape = layout::PitchLinearShape<kElementsPerAccess, 1>;
|
| 665 |
+
|
| 666 |
+
/// Internal details made public to facilitate introspection
|
| 667 |
+
struct Detail {
|
| 668 |
+
|
| 669 |
+
/// Fixed arrangement of threads within a warp (units of threads).
|
| 670 |
+
using WarpThreadArrangement = WarpThreadArrangement_;
|
| 671 |
+
|
| 672 |
+
/// Number of threads per warp
|
| 673 |
+
static int const kWarpSize = WarpThreadArrangement::kCount;
|
| 674 |
+
|
| 675 |
+
/// Number of participating warps
|
| 676 |
+
static int const kWarpCount = kThreads / kWarpSize;
|
| 677 |
+
|
| 678 |
+
static_assert(
|
| 679 |
+
!(Shape::kContiguous % kElementsPerAccess),
|
| 680 |
+
"Shape must be divisible by vector length.");
|
| 681 |
+
|
| 682 |
+
/// Compute the 'shape' of the overall tile in units of vectors
|
| 683 |
+
using ShapeInAccesses = layout::PitchLinearShape<
|
| 684 |
+
Shape::kContiguous / kElementsPerAccess,
|
| 685 |
+
Shape::kStrided
|
| 686 |
+
>;
|
| 687 |
+
|
| 688 |
+
// compute number of warp-level accesses total
|
| 689 |
+
using WarpAccessIterations = layout::PitchLinearShape<
|
| 690 |
+
ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous,
|
| 691 |
+
ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided
|
| 692 |
+
>;
|
| 693 |
+
|
| 694 |
+
// Divide it into the number of warps, first partitioning the strided dimension then the
|
| 695 |
+
// contiguous.
|
| 696 |
+
static int const kWarpsStrided =
|
| 697 |
+
(WarpAccessIterations::kStrided >= kWarpCount
|
| 698 |
+
? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided));
|
| 699 |
+
|
| 700 |
+
static int const kWarpsContiguous =
|
| 701 |
+
(kWarpCount > WarpAccessIterations::kStrided ?
|
| 702 |
+
WarpAccessIterations::kContiguous / kWarpsStrided : 1);
|
| 703 |
+
|
| 704 |
+
/// Arrangement of warps within a threadblock-scoped tile
|
| 705 |
+
using WarpArrangement = layout::PitchLinearShape<
|
| 706 |
+
kWarpsContiguous, kWarpsStrided
|
| 707 |
+
>;
|
| 708 |
+
};
|
| 709 |
+
|
| 710 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 711 |
+
using Iterations = layout::PitchLinearShape<
|
| 712 |
+
Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous,
|
| 713 |
+
Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided
|
| 714 |
+
>;
|
| 715 |
+
|
| 716 |
+
static_assert(Iterations::kCount,
|
| 717 |
+
"Number of iterations must be non-zero");
|
| 718 |
+
|
| 719 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 720 |
+
using Delta = layout::PitchLinearShape<
|
| 721 |
+
Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
|
| 722 |
+
Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided
|
| 723 |
+
>;
|
| 724 |
+
|
| 725 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
| 726 |
+
CUTLASS_HOST_DEVICE
|
| 727 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 728 |
+
|
| 729 |
+
int warp_id = (thread_id / Detail::kWarpSize);
|
| 730 |
+
int lane_id = (thread_id % Detail::kWarpSize);
|
| 731 |
+
|
| 732 |
+
//
|
| 733 |
+
// compute warp-level offset
|
| 734 |
+
//
|
| 735 |
+
|
| 736 |
+
// This is the shape of the entire area covered by a warp's memory access (in units of vectors)
|
| 737 |
+
layout::PitchLinearCoord warp_footprint{
|
| 738 |
+
Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous,
|
| 739 |
+
Detail::WarpThreadArrangement::kStrided
|
| 740 |
+
};
|
| 741 |
+
|
| 742 |
+
// This is the offset of a specific warp (in units of vectors)
|
| 743 |
+
layout::PitchLinearCoord warp_offset{
|
| 744 |
+
(warp_id % Detail::kWarpsContiguous),
|
| 745 |
+
(warp_id / Detail::kWarpsContiguous)
|
| 746 |
+
};
|
| 747 |
+
|
| 748 |
+
// This is the offset of a specific thread within a warp (units of vectors)
|
| 749 |
+
layout::PitchLinearCoord thread_offset_in_warp{
|
| 750 |
+
lane_id % Detail::WarpThreadArrangement::kContiguous,
|
| 751 |
+
lane_id / Detail::WarpThreadArrangement::kContiguous
|
| 752 |
+
};
|
| 753 |
+
|
| 754 |
+
// This is the offset of a thread within a threadblock tile (units of vectors)
|
| 755 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec =
|
| 756 |
+
warp_footprint * warp_offset + thread_offset_in_warp;
|
| 757 |
+
|
| 758 |
+
// This is the offset of a thread within a threadblock tile (units of elements)
|
| 759 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{
|
| 760 |
+
thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess,
|
| 761 |
+
thread_offset_in_threadblock_tile_vec.strided()
|
| 762 |
+
};
|
| 763 |
+
|
| 764 |
+
return thread_offset_in_threadblock_tile_base;
|
| 765 |
+
}
|
| 766 |
+
};
|
| 767 |
+
|
| 768 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 769 |
+
/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous
|
| 770 |
+
/// dimension then along the strided dimension, while each thread access a 2D thread-tile.
|
| 771 |
+
///
|
| 772 |
+
/// The tile must be divisible by the thread count such that all threads may execute the same
|
| 773 |
+
/// number of iterations with the same delta to exhaustively cover the tile.
|
| 774 |
+
///
|
| 775 |
+
/// This class satisfies the "RegularThreadMapping" concept.
|
| 776 |
+
template <
|
| 777 |
+
typename Shape_,
|
| 778 |
+
int Threads,
|
| 779 |
+
typename ThreadTileShape
|
| 780 |
+
>
|
| 781 |
+
struct PitchLinear2DThreadTileStripminedThreadMap;
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
template <
|
| 785 |
+
typename Shape_,
|
| 786 |
+
int Threads
|
| 787 |
+
>
|
| 788 |
+
struct PitchLinear2DThreadTileStripminedThreadMap <Shape_, Threads, cutlass::layout::PitchLinearShape<4, 4>>{
|
| 789 |
+
|
| 790 |
+
/// Tensor coordinate
|
| 791 |
+
using TensorCoord = layout::PitchLinearCoord;
|
| 792 |
+
|
| 793 |
+
/// Tile shape
|
| 794 |
+
using Shape = Shape_;
|
| 795 |
+
|
| 796 |
+
/// Access Shape of each thread
|
| 797 |
+
using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>;
|
| 798 |
+
//using ThreadAccessShape = ThreadTileShape;
|
| 799 |
+
|
| 800 |
+
/// Number of threads total
|
| 801 |
+
static int const kThreads = Threads;
|
| 802 |
+
|
| 803 |
+
/// Extract length of each access from Layout
|
| 804 |
+
static int const kElementsPerAccess = ThreadAccessShape::kContiguous;
|
| 805 |
+
|
| 806 |
+
static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)");
|
| 807 |
+
|
| 808 |
+
/// Internal implementation details
|
| 809 |
+
struct Detail {
|
| 810 |
+
|
| 811 |
+
static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4");
|
| 812 |
+
|
| 813 |
+
static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), "");
|
| 814 |
+
|
| 815 |
+
static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)),
|
| 816 |
+
"Shape must be divisible thread count * accesses per thread.");
|
| 817 |
+
|
| 818 |
+
/// Shape of the tile in units of vectors
|
| 819 |
+
using ShapeVec = layout::PitchLinearShape<
|
| 820 |
+
Shape::kContiguous / ThreadAccessShape::kContiguous,
|
| 821 |
+
Shape::kStrided / ThreadAccessShape::kStrided
|
| 822 |
+
>;
|
| 823 |
+
|
| 824 |
+
static_assert(
|
| 825 |
+
(Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) ||
|
| 826 |
+
(!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))),
|
| 827 |
+
"Shape must be divisible by number of iterations of each thread."
|
| 828 |
+
);
|
| 829 |
+
};
|
| 830 |
+
|
| 831 |
+
/// Number of iterations by each thread
|
| 832 |
+
using Iterations = typename platform::conditional<
|
| 833 |
+
Threads >= Detail::ShapeVec::kContiguous,
|
| 834 |
+
layout::PitchLinearShape<
|
| 835 |
+
1,
|
| 836 |
+
// Redo the comparison here to work around divide by zero compiler
|
| 837 |
+
// error. The compiler evaluates both path of platform::conditional.
|
| 838 |
+
(Threads >= Detail::ShapeVec::kContiguous
|
| 839 |
+
? Detail::ShapeVec::kStrided /
|
| 840 |
+
(kThreads / Detail::ShapeVec::kContiguous)
|
| 841 |
+
: 0)>,
|
| 842 |
+
layout::PitchLinearShape<Detail::ShapeVec::kContiguous / kThreads,
|
| 843 |
+
Detail::ShapeVec::kStrided>>::type;
|
| 844 |
+
|
| 845 |
+
/// Interval between accesses along each dimension of the tensor's logical coordinate space
|
| 846 |
+
/// (in units of Elements)
|
| 847 |
+
using Delta = typename platform::conditional<
|
| 848 |
+
Threads >= Detail::ShapeVec::kContiguous,
|
| 849 |
+
layout::PitchLinearShape<
|
| 850 |
+
Shape::kContiguous,
|
| 851 |
+
kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous
|
| 852 |
+
>,
|
| 853 |
+
layout::PitchLinearShape<
|
| 854 |
+
kThreads * ThreadAccessShape::kContiguous,
|
| 855 |
+
1
|
| 856 |
+
>
|
| 857 |
+
>::type;
|
| 858 |
+
|
| 859 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
| 860 |
+
/// (in units of Elements)
|
| 861 |
+
CUTLASS_HOST_DEVICE
|
| 862 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 863 |
+
|
| 864 |
+
return TensorCoord(
|
| 865 |
+
(thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous,
|
| 866 |
+
(thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided);
|
| 867 |
+
}
|
| 868 |
+
};
|
| 869 |
+
|
| 870 |
+
/// Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping
|
| 871 |
+
template <typename ThreadMap_>
|
| 872 |
+
struct TransposePitchLinearThreadMap2DThreadTile {
|
| 873 |
+
/// Underlying ThreadMap
|
| 874 |
+
using ThreadMap = ThreadMap_;
|
| 875 |
+
|
| 876 |
+
/// Tensor coordinate
|
| 877 |
+
using TensorCoord = typename ThreadMap::TensorCoord;
|
| 878 |
+
|
| 879 |
+
/// Tile shape
|
| 880 |
+
using Shape = typename ThreadMap::Shape;
|
| 881 |
+
|
| 882 |
+
/// Number of threads total
|
| 883 |
+
static int const kThreads = ThreadMap::kThreads;
|
| 884 |
+
|
| 885 |
+
/// Extract vector length from Layout
|
| 886 |
+
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1");
|
| 890 |
+
///< Iterations along each dimension (concept: PitchLinearShape)
|
| 891 |
+
using Iterations =
|
| 892 |
+
layout::PitchLinearShape<ThreadMap::Iterations::kStrided,
|
| 893 |
+
ThreadMap::Iterations::kContiguous>;
|
| 894 |
+
|
| 895 |
+
static_assert(Iterations::kCount, "Number of iterations must be non-zero");
|
| 896 |
+
|
| 897 |
+
/// Shape of access by each thread
|
| 898 |
+
using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
|
| 899 |
+
|
| 900 |
+
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
| 901 |
+
using Delta =
|
| 902 |
+
layout::PitchLinearShape<ThreadMap::Delta::kStrided,
|
| 903 |
+
ThreadMap::Delta::kContiguous>;
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
/// Maps thread ID to a coordinate offset within the tensor's logical
|
| 907 |
+
/// coordinate space Note this is slightly different from the one of
|
| 908 |
+
/// PitchLinearWarpRakedThreadMap.
|
| 909 |
+
CUTLASS_HOST_DEVICE
|
| 910 |
+
static TensorCoord initial_offset(int thread_id) {
|
| 911 |
+
|
| 912 |
+
TensorCoord coord = ThreadMap::initial_offset(thread_id);
|
| 913 |
+
return TensorCoord(
|
| 914 |
+
coord.strided(),
|
| 915 |
+
coord.contiguous()
|
| 916 |
+
);
|
| 917 |
+
}
|
| 918 |
+
};
|
| 919 |
+
|
| 920 |
+
|
| 921 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 922 |
+
|
| 923 |
+
} // namespace transform
|
| 924 |
+
} // namespace cutlass
|
| 925 |
+
|
| 926 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Basic copy routines for tensor views
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
namespace cutlass {
|
| 39 |
+
namespace transform {
|
| 40 |
+
namespace thread {
|
| 41 |
+
|
| 42 |
+
/// Transforms a fragment by doing a transpose
|
| 43 |
+
template <
|
| 44 |
+
int ElementCount,
|
| 45 |
+
typename TransposeShape,
|
| 46 |
+
typename Element
|
| 47 |
+
> struct Transpose;
|
| 48 |
+
|
| 49 |
+
/// Specialization for int8_t 4x4 transpose
|
| 50 |
+
template <int ElementCount_>
|
| 51 |
+
struct Transpose<ElementCount_, layout::PitchLinearShape<4,4> , int8_t> {
|
| 52 |
+
|
| 53 |
+
static const int kElementCount = ElementCount_;
|
| 54 |
+
using TransposeShape = layout::PitchLinearShape<4,4>;
|
| 55 |
+
using Element = int8_t;
|
| 56 |
+
using Fragment = cutlass::Array<Element, kElementCount>;
|
| 57 |
+
|
| 58 |
+
static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose");
|
| 59 |
+
|
| 60 |
+
CUTLASS_DEVICE
|
| 61 |
+
void transform(Fragment& dst, Fragment& src) {
|
| 62 |
+
|
| 63 |
+
// Expose src/dst as int arrays.
|
| 64 |
+
int* src_int = reinterpret_cast<int*>(&src);
|
| 65 |
+
int* dst_int = reinterpret_cast<int*>(&dst);
|
| 66 |
+
|
| 67 |
+
CUTLASS_PRAGMA_UNROLL
|
| 68 |
+
for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){
|
| 69 |
+
|
| 70 |
+
int const i0 = 4 * i + 0;
|
| 71 |
+
int const i1 = 4 * i + 1;
|
| 72 |
+
int const i2 = 4 * i + 2;
|
| 73 |
+
int const i3 = 4 * i + 3;
|
| 74 |
+
|
| 75 |
+
int a0 = src_int[i0];
|
| 76 |
+
int a1 = src_int[i1];
|
| 77 |
+
int a2 = src_int[i2];
|
| 78 |
+
int a3 = src_int[i3];
|
| 79 |
+
|
| 80 |
+
int b0, b1, b2, b3, c0;
|
| 81 |
+
b0 = __byte_perm(a0, a1, 0x0040);
|
| 82 |
+
c0 = __byte_perm(a2, a3, 0x0040);
|
| 83 |
+
b0 = __byte_perm(b0, c0, 0x5410);
|
| 84 |
+
|
| 85 |
+
b1 = __byte_perm(a0, a1, 0x0051);
|
| 86 |
+
c0 = __byte_perm(a2, a3, 0x0051);
|
| 87 |
+
b1 = __byte_perm(b1, c0, 0x5410);
|
| 88 |
+
|
| 89 |
+
b2 = __byte_perm(a0, a1, 0x0062);
|
| 90 |
+
c0 = __byte_perm(a2, a3, 0x0062);
|
| 91 |
+
b2 = __byte_perm(b2, c0, 0x5410);
|
| 92 |
+
|
| 93 |
+
b3 = __byte_perm(a0, a1, 0x0073);
|
| 94 |
+
c0 = __byte_perm(a2, a3, 0x0073);
|
| 95 |
+
b3 = __byte_perm(b3, c0, 0x5410);
|
| 96 |
+
|
| 97 |
+
dst_int[i0] = b0;
|
| 98 |
+
dst_int[i1] = b1;
|
| 99 |
+
dst_int[i2] = b2;
|
| 100 |
+
dst_int[i3] = b3;
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
} // namespace thread
|
| 106 |
+
} // namespace layout
|
| 107 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
#pragma once
|
| 32 |
+
|
| 33 |
+
#include "cutlass/cutlass.h"
|
| 34 |
+
#include "cutlass/complex.h"
|
| 35 |
+
|
| 36 |
+
namespace cutlass {
|
| 37 |
+
namespace transform {
|
| 38 |
+
namespace thread {
|
| 39 |
+
|
| 40 |
+
namespace UnaryTransform {
|
| 41 |
+
struct Identity; ///< None (i.e., identity)
|
| 42 |
+
struct Conjugate; ///< Complex conjugate
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
/// Element-wise unary operator that transforms one element of a fragment at a time
|
| 46 |
+
template<
|
| 47 |
+
typename FragmentIn, ///< Input Fragment
|
| 48 |
+
typename FragmentOut,///< Output Fragment
|
| 49 |
+
typename Transform> ///< Unary transform operator
|
| 50 |
+
class UnaryOp
|
| 51 |
+
{
|
| 52 |
+
public:
|
| 53 |
+
CUTLASS_DEVICE
|
| 54 |
+
static FragmentOut execute(FragmentIn &in)
|
| 55 |
+
{
|
| 56 |
+
static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match.");
|
| 57 |
+
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
| 58 |
+
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
| 59 |
+
"Unary Operator not supported.");
|
| 60 |
+
|
| 61 |
+
FragmentOut out;
|
| 62 |
+
if (platform::is_same<Transform, UnaryTransform::Identity>::value )
|
| 63 |
+
{
|
| 64 |
+
CUTLASS_PRAGMA_UNROLL
|
| 65 |
+
for (int i=0; i < FragmentIn::kElements; ++i){
|
| 66 |
+
out[i] = static_cast<typename FragmentOut::Element>(in[i]);
|
| 67 |
+
}
|
| 68 |
+
}
|
| 69 |
+
else if (platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
| 70 |
+
{
|
| 71 |
+
for (int i=0; i < FragmentIn::kElements; ++i){
|
| 72 |
+
out[i] = conj(static_cast<typename FragmentOut::Element>(in[i]));
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
return out;
|
| 76 |
+
}
|
| 77 |
+
};
|
| 78 |
+
|
| 79 |
+
template<typename FragmentIn, typename Transform>
|
| 80 |
+
class UnaryOp<FragmentIn, FragmentIn, Transform>
|
| 81 |
+
{
|
| 82 |
+
public:
|
| 83 |
+
CUTLASS_DEVICE
|
| 84 |
+
static FragmentIn execute(FragmentIn &in)
|
| 85 |
+
{
|
| 86 |
+
static_assert(platform::is_same<Transform, UnaryTransform::Identity>::value ||
|
| 87 |
+
platform::is_same<Transform, UnaryTransform::Conjugate>::value,
|
| 88 |
+
"Unary Operator not supported.");
|
| 89 |
+
|
| 90 |
+
if (platform::is_same<Transform, UnaryTransform::Identity>::value )
|
| 91 |
+
{
|
| 92 |
+
return in;
|
| 93 |
+
}
|
| 94 |
+
else if (platform::is_same<Transform, UnaryTransform::Conjugate>::value )
|
| 95 |
+
{
|
| 96 |
+
for(int i=0; i < FragmentIn::kElements; ++i){
|
| 97 |
+
in[i] = conj(in[i]);
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
return in;
|
| 101 |
+
}
|
| 102 |
+
};
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
}
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Ell iterator for matrix of indices (ellColInd matrix)
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
namespace cutlass {
|
| 38 |
+
namespace transform {
|
| 39 |
+
namespace threadblock {
|
| 40 |
+
|
| 41 |
+
namespace ell{
|
| 42 |
+
|
| 43 |
+
constexpr unsigned int SmemPow = 8;
|
| 44 |
+
constexpr unsigned int SmemStages = 2;
|
| 45 |
+
constexpr unsigned int SmemSize = 1 << SmemPow;
|
| 46 |
+
constexpr unsigned int SmemMask = (SmemSize*SmemStages-1);
|
| 47 |
+
|
| 48 |
+
class SharedStorage{
|
| 49 |
+
public:
|
| 50 |
+
Array<int, SmemSize*SmemStages> array;
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
class Iterator{
|
| 54 |
+
public:
|
| 55 |
+
using Layout = layout::PitchLinear;
|
| 56 |
+
using LongIndex = typename Layout::LongIndex;
|
| 57 |
+
|
| 58 |
+
private:
|
| 59 |
+
const int *gmem_col_idx_;
|
| 60 |
+
int *smem_col_idx_;
|
| 61 |
+
const int block_size_;
|
| 62 |
+
const int base_idx_;
|
| 63 |
+
const int k_shape_;
|
| 64 |
+
const int ell_increment_;
|
| 65 |
+
const int array_length_;
|
| 66 |
+
int col_idx_base_;
|
| 67 |
+
int residue_;
|
| 68 |
+
int counter_;
|
| 69 |
+
|
| 70 |
+
int pow2_;
|
| 71 |
+
int residue_shape_;
|
| 72 |
+
|
| 73 |
+
int smem_offset_;
|
| 74 |
+
int smem_stage_;
|
| 75 |
+
int gmem_offset_;
|
| 76 |
+
|
| 77 |
+
int lane_;
|
| 78 |
+
|
| 79 |
+
bool is_pow2_;
|
| 80 |
+
bool is_residue_tile_;
|
| 81 |
+
|
| 82 |
+
public:
|
| 83 |
+
CUTLASS_DEVICE
|
| 84 |
+
void load_ell_indices(){
|
| 85 |
+
for(int i=threadIdx.x; i<SmemSize; i+=blockDim.x){
|
| 86 |
+
int idx = (gmem_offset_+i < array_length_) ? gmem_offset_+i : array_length_-1;
|
| 87 |
+
int gmem_col_idx = gmem_col_idx_[idx] - base_idx_;
|
| 88 |
+
smem_col_idx_[i + smem_stage_ * SmemSize] =
|
| 89 |
+
(gmem_col_idx >= 0) ? gmem_col_idx : -1;
|
| 90 |
+
}
|
| 91 |
+
gmem_offset_ += SmemSize;
|
| 92 |
+
smem_stage_ ^= 1;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
CUTLASS_DEVICE
|
| 96 |
+
Iterator(
|
| 97 |
+
SharedStorage& shared_storage_base,
|
| 98 |
+
const int* col_idx,
|
| 99 |
+
const int& block_size,
|
| 100 |
+
const int& base_idx,
|
| 101 |
+
const int k_shape,
|
| 102 |
+
const int& problem_size_k,
|
| 103 |
+
const int& ell_stride,
|
| 104 |
+
const int& thread_idx)
|
| 105 |
+
: residue_(0),
|
| 106 |
+
counter_(0),
|
| 107 |
+
smem_offset_(0),
|
| 108 |
+
smem_stage_(0),
|
| 109 |
+
gmem_offset_(0),
|
| 110 |
+
block_size_(block_size),
|
| 111 |
+
base_idx_(base_idx),
|
| 112 |
+
k_shape_(k_shape),
|
| 113 |
+
ell_increment_(ell_stride * block_size),
|
| 114 |
+
array_length_((problem_size_k + block_size_ - 1) / block_size_),
|
| 115 |
+
residue_shape_(problem_size_k % k_shape_),
|
| 116 |
+
is_residue_tile_(residue_shape_ != 0),
|
| 117 |
+
smem_col_idx_(reinterpret_cast<int*>(&shared_storage_base.array)),
|
| 118 |
+
gmem_col_idx_(const_cast<int*>(col_idx)),
|
| 119 |
+
lane_(thread_idx % 32) {
|
| 120 |
+
|
| 121 |
+
load_ell_indices();
|
| 122 |
+
__syncthreads();
|
| 123 |
+
|
| 124 |
+
is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0);
|
| 125 |
+
if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0;
|
| 126 |
+
|
| 127 |
+
col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_;
|
| 128 |
+
|
| 129 |
+
pow2_ = 0;
|
| 130 |
+
while(block_size_ >> (pow2_ + 1)) ++pow2_;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
CUTLASS_DEVICE
|
| 134 |
+
int get_blocksize(){
|
| 135 |
+
return block_size_;
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
CUTLASS_DEVICE
|
| 139 |
+
Iterator &operator++(){
|
| 140 |
+
if(is_residue_tile_){
|
| 141 |
+
residue_ += residue_shape_;
|
| 142 |
+
is_residue_tile_ = false;
|
| 143 |
+
} else {
|
| 144 |
+
residue_ += k_shape_;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
if(residue_ < block_size_){
|
| 148 |
+
return *this;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_))
|
| 152 |
+
load_ell_indices();
|
| 153 |
+
|
| 154 |
+
if(residue_ == block_size_){
|
| 155 |
+
++smem_offset_;
|
| 156 |
+
counter_ += ell_increment_;
|
| 157 |
+
residue_ = 0;
|
| 158 |
+
col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_;
|
| 159 |
+
return *this;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if(is_pow2_){
|
| 163 |
+
smem_offset_ += residue_ >> pow2_;
|
| 164 |
+
counter_ += (residue_ >> pow2_) * ell_increment_;
|
| 165 |
+
residue_ = residue_ & ((1 << pow2_) - 1);
|
| 166 |
+
}
|
| 167 |
+
else {
|
| 168 |
+
smem_offset_ += residue_ / block_size_;
|
| 169 |
+
counter_ += (residue_ / block_size_) * ell_increment_;
|
| 170 |
+
residue_ %= block_size_;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_;
|
| 174 |
+
|
| 175 |
+
return *this;
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
CUTLASS_DEVICE
|
| 179 |
+
LongIndex get_offset(const int& idx) {
|
| 180 |
+
int num_jump_tiles;
|
| 181 |
+
if(is_pow2_)
|
| 182 |
+
num_jump_tiles = (idx + residue_) >> pow2_;
|
| 183 |
+
else
|
| 184 |
+
num_jump_tiles = (idx + residue_) / block_size_;
|
| 185 |
+
|
| 186 |
+
int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles);
|
| 187 |
+
return tmp - num_jump_tiles * ell_increment_;
|
| 188 |
+
}
|
| 189 |
+
|
| 190 |
+
CUTLASS_DEVICE
|
| 191 |
+
LongIndex get_offset_fast() {
|
| 192 |
+
return col_idx_base_;
|
| 193 |
+
}
|
| 194 |
+
};
|
| 195 |
+
|
| 196 |
+
}
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h
ADDED
|
@@ -0,0 +1,1350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/array.h"
|
| 38 |
+
#include "cutlass/coord.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/layout/matrix.h"
|
| 41 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 42 |
+
#include "cutlass/matrix_shape.h"
|
| 43 |
+
#include "cutlass/predicate_vector.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
#include "cutlass/tensor_view.h"
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace transform {
|
| 53 |
+
namespace threadblock {
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
/// EllPredicatedTileAccessIterator
|
| 58 |
+
///
|
| 59 |
+
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
|
| 60 |
+
typename ThreadMap, typename AccessType>
|
| 61 |
+
class EllPredicatedTileAccessIterator;
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
|
| 66 |
+
///
|
| 67 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 68 |
+
typename ThreadMap_, typename AccessType_>
|
| 69 |
+
class EllPredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
| 70 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 71 |
+
public:
|
| 72 |
+
static_assert(
|
| 73 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 74 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 75 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 76 |
+
|
| 77 |
+
using Shape = Shape_;
|
| 78 |
+
using Element = Element_;
|
| 79 |
+
using Layout = layout::PitchLinear;
|
| 80 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 81 |
+
using ThreadMap = ThreadMap_;
|
| 82 |
+
using AccessType = AccessType_;
|
| 83 |
+
|
| 84 |
+
using Index = typename Layout::Index;
|
| 85 |
+
using LongIndex = typename Layout::LongIndex;
|
| 86 |
+
|
| 87 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 88 |
+
using TensorView = TensorView<Element, Layout>;
|
| 89 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 90 |
+
|
| 91 |
+
using Pointer = Element *;
|
| 92 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 93 |
+
|
| 94 |
+
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
| 95 |
+
|
| 96 |
+
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
| 97 |
+
"Vectors implied by the thread map must be divisible by the access type.");
|
| 98 |
+
|
| 99 |
+
static int const kPredicatesPerByte = 4;
|
| 100 |
+
static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
|
| 101 |
+
|
| 102 |
+
static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
|
| 103 |
+
|
| 104 |
+
/// Number of 32b words containing predicates
|
| 105 |
+
static int const kPredicateByteCount =
|
| 106 |
+
(kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
| 107 |
+
static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
|
| 108 |
+
|
| 109 |
+
static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
|
| 110 |
+
|
| 111 |
+
static_assert(kPredicateWordCount <= 4, "Too many predicates.");
|
| 112 |
+
|
| 113 |
+
/// Predicate vector stores mask to guard accesses
|
| 114 |
+
using Mask = Array<uint32_t, kPredicateWordCount>;
|
| 115 |
+
|
| 116 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 117 |
+
class Params {
|
| 118 |
+
public:
|
| 119 |
+
friend EllPredicatedTileAccessIterator;
|
| 120 |
+
|
| 121 |
+
private:
|
| 122 |
+
/// stride of pitch-linear layout (units of Element)
|
| 123 |
+
LongIndex stride_;
|
| 124 |
+
/// amount (in byte) to increment pointer to move to next access along
|
| 125 |
+
/// strided dimension
|
| 126 |
+
LongIndex inc_strided_;
|
| 127 |
+
/// amount (in byte) to increment pointer from last access to first access
|
| 128 |
+
/// of next tile
|
| 129 |
+
LongIndex inc_next_;
|
| 130 |
+
/// amount (in byte) to increment pointer from first access of current tile
|
| 131 |
+
/// to first access of next tile
|
| 132 |
+
LongIndex inc_advance_;
|
| 133 |
+
|
| 134 |
+
public:
|
| 135 |
+
|
| 136 |
+
// Default ctor
|
| 137 |
+
CUTLASS_HOST_DEVICE
|
| 138 |
+
Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
|
| 139 |
+
|
| 140 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 141 |
+
CUTLASS_HOST_DEVICE
|
| 142 |
+
Params(Layout const &layout) : stride_(layout.stride(0)) {
|
| 143 |
+
inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
|
| 144 |
+
sizeof_bits<Element>::value / 8;
|
| 145 |
+
|
| 146 |
+
if (kAdvanceRank) {
|
| 147 |
+
// advance along strided dimension
|
| 148 |
+
inc_advance_ =
|
| 149 |
+
Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
|
| 150 |
+
} else {
|
| 151 |
+
// advance along contiguous dimension
|
| 152 |
+
inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
|
| 156 |
+
ThreadMap::Delta::kStrided * LongIndex(stride_) *
|
| 157 |
+
sizeof_bits<Element>::value / 8;
|
| 158 |
+
};
|
| 159 |
+
};
|
| 160 |
+
|
| 161 |
+
private:
|
| 162 |
+
/// Internal pointer type permits fast address arithmetic
|
| 163 |
+
using BytePointer = char *;
|
| 164 |
+
|
| 165 |
+
private:
|
| 166 |
+
//
|
| 167 |
+
// Data members
|
| 168 |
+
//
|
| 169 |
+
|
| 170 |
+
/// Parameters object with precomputed internal state
|
| 171 |
+
Params const ¶ms_;
|
| 172 |
+
|
| 173 |
+
/// Internal pointer to first access of tile
|
| 174 |
+
BytePointer pointer_;
|
| 175 |
+
|
| 176 |
+
/// Guard predicates
|
| 177 |
+
uint32_t predicates_[kPredicateWordCount];
|
| 178 |
+
|
| 179 |
+
/// Size of tensor
|
| 180 |
+
TensorCoord extent_;
|
| 181 |
+
|
| 182 |
+
/// Initial offset for each thread
|
| 183 |
+
TensorCoord thread_offset_;
|
| 184 |
+
|
| 185 |
+
/// Offset to the first steady-state tile
|
| 186 |
+
TensorCoord residue_offset_;
|
| 187 |
+
|
| 188 |
+
/// Initial offset to define ELL block
|
| 189 |
+
TensorCoord ell_offset_;
|
| 190 |
+
|
| 191 |
+
/// Used for out-of-order visitation
|
| 192 |
+
bool is_residue_tile_;
|
| 193 |
+
|
| 194 |
+
/// Iteration along vectors implied by the thread map
|
| 195 |
+
int iteration_vector_;
|
| 196 |
+
|
| 197 |
+
/// Iteration in the contiguous dimension
|
| 198 |
+
int iteration_contiguous_;
|
| 199 |
+
|
| 200 |
+
/// Iteration in the strided dimension
|
| 201 |
+
int iteration_strided_;
|
| 202 |
+
|
| 203 |
+
public:
|
| 204 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 205 |
+
CUTLASS_DEVICE
|
| 206 |
+
void compute_predicates_(
|
| 207 |
+
/// Extent of the matrix window
|
| 208 |
+
TensorCoord extent,
|
| 209 |
+
/// optionally, simplify predicate calculation during 'steady state' phase
|
| 210 |
+
bool is_steady_state = false) {
|
| 211 |
+
|
| 212 |
+
CUTLASS_PRAGMA_UNROLL
|
| 213 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 214 |
+
predicates_[i] = 0u;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
CUTLASS_PRAGMA_UNROLL
|
| 218 |
+
for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
|
| 219 |
+
|
| 220 |
+
int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 221 |
+
|
| 222 |
+
int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 223 |
+
|
| 224 |
+
int c = access_residual / kAccessesPerVector;
|
| 225 |
+
int v = access_residual % kAccessesPerVector;
|
| 226 |
+
|
| 227 |
+
TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
|
| 228 |
+
s * ThreadMap::Delta::kStrided);
|
| 229 |
+
|
| 230 |
+
TensorCoord coord = thread_offset_ + iteration_coord;
|
| 231 |
+
|
| 232 |
+
bool guard;
|
| 233 |
+
|
| 234 |
+
if (is_steady_state) {
|
| 235 |
+
if (kAdvanceRank == 0) {
|
| 236 |
+
guard = (coord.strided() < extent.strided());
|
| 237 |
+
} else {
|
| 238 |
+
guard = (coord.contiguous() < extent.contiguous());
|
| 239 |
+
}
|
| 240 |
+
} else {
|
| 241 |
+
guard = (coord.strided() < extent.strided() &&
|
| 242 |
+
coord.contiguous() < extent.contiguous());
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
|
| 246 |
+
|
| 247 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 248 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 249 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 250 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 251 |
+
|
| 252 |
+
predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
|
| 253 |
+
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
public:
|
| 259 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 260 |
+
/// and thread ID
|
| 261 |
+
CUTLASS_HOST_DEVICE
|
| 262 |
+
EllPredicatedTileAccessIterator(
|
| 263 |
+
/// Precomputed parameters object
|
| 264 |
+
Params const ¶ms,
|
| 265 |
+
/// Pointer to start of tensor
|
| 266 |
+
Pointer pointer,
|
| 267 |
+
/// Extent of tensor
|
| 268 |
+
TensorCoord extent,
|
| 269 |
+
/// ID of each participating thread
|
| 270 |
+
int thread_id,
|
| 271 |
+
/// Initial offset of threadblock
|
| 272 |
+
TensorCoord const &threadblock_offset)
|
| 273 |
+
: params_(params),
|
| 274 |
+
pointer_(reinterpret_cast<BytePointer>(
|
| 275 |
+
const_cast<NonConstPointer>(pointer))),
|
| 276 |
+
extent_(extent),
|
| 277 |
+
is_residue_tile_(true) {
|
| 278 |
+
|
| 279 |
+
TensorCoord residue_extent;
|
| 280 |
+
if (kAdvanceRank) {
|
| 281 |
+
|
| 282 |
+
typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
|
| 283 |
+
if (!residue_size) {
|
| 284 |
+
residue_size = Shape::kStrided;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
residue_offset_ = make_Coord(0, residue_size);
|
| 288 |
+
residue_extent = make_Coord(
|
| 289 |
+
extent_.contiguous(),
|
| 290 |
+
min(threadblock_offset.strided() + residue_size, extent_.strided())
|
| 291 |
+
);
|
| 292 |
+
} else {
|
| 293 |
+
|
| 294 |
+
typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
|
| 295 |
+
if (!residue_size) {
|
| 296 |
+
residue_size = Shape::kContiguous;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
residue_offset_ = make_Coord(residue_size, 0);
|
| 300 |
+
|
| 301 |
+
residue_extent = make_Coord(
|
| 302 |
+
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
|
| 303 |
+
extent_.strided()
|
| 304 |
+
);
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
// Per-thread offset in logical coordinates of tensor
|
| 308 |
+
ell_offset_ = ThreadMap::initial_offset(thread_id);
|
| 309 |
+
thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
|
| 310 |
+
|
| 311 |
+
// update internal pointers
|
| 312 |
+
Layout layout(params_.stride_);
|
| 313 |
+
add_pointer_offset(layout(thread_offset_));
|
| 314 |
+
|
| 315 |
+
compute_predicates_(residue_extent, false);
|
| 316 |
+
|
| 317 |
+
set_iteration_index(0);
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
/// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
|
| 321 |
+
CUTLASS_HOST_DEVICE
|
| 322 |
+
EllPredicatedTileAccessIterator(
|
| 323 |
+
/// Precomputed parameters object
|
| 324 |
+
Params const ¶ms,
|
| 325 |
+
/// Pointer to start of tensor
|
| 326 |
+
Pointer pointer,
|
| 327 |
+
/// Extent of tensor
|
| 328 |
+
TensorCoord extent,
|
| 329 |
+
///< ID of each participating thread
|
| 330 |
+
int thread_id)
|
| 331 |
+
: EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 332 |
+
make_Coord(0, 0)) {}
|
| 333 |
+
|
| 334 |
+
/// Overrides the internal iteration index
|
| 335 |
+
CUTLASS_HOST_DEVICE
|
| 336 |
+
void set_iteration_index(int index) {
|
| 337 |
+
|
| 338 |
+
iteration_vector_ = index % kAccessesPerVector;
|
| 339 |
+
int residual_access = index / kAccessesPerVector;
|
| 340 |
+
|
| 341 |
+
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
| 342 |
+
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
| 343 |
+
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
/// Adds a pointer offset in units of Element
|
| 347 |
+
CUTLASS_HOST_DEVICE
|
| 348 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 349 |
+
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
| 353 |
+
CUTLASS_DEVICE
|
| 354 |
+
void add_tile_offset(
|
| 355 |
+
TensorCoord const &tile_offset) {
|
| 356 |
+
if (is_residue_tile_) {
|
| 357 |
+
|
| 358 |
+
thread_offset_ += residue_offset_;
|
| 359 |
+
|
| 360 |
+
Layout layout(params_.stride_);
|
| 361 |
+
add_pointer_offset(layout(residue_offset_));
|
| 362 |
+
|
| 363 |
+
compute_predicates_(extent_, true);
|
| 364 |
+
|
| 365 |
+
if (kAdvanceRank) {
|
| 366 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
|
| 367 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 368 |
+
} else {
|
| 369 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
|
| 370 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 371 |
+
}
|
| 372 |
+
} else {
|
| 373 |
+
if (kAdvanceRank) {
|
| 374 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
|
| 375 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 376 |
+
} else {
|
| 377 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
|
| 378 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
is_residue_tile_ = false;
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
/// Returns a pointer
|
| 385 |
+
CUTLASS_HOST_DEVICE
|
| 386 |
+
AccessType *get() const {
|
| 387 |
+
return reinterpret_cast<AccessType *>(
|
| 388 |
+
pointer_ +
|
| 389 |
+
iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/// Returns a k_location
|
| 393 |
+
CUTLASS_HOST_DEVICE
|
| 394 |
+
int get_k() const {
|
| 395 |
+
if(kAdvanceRank){ //strided
|
| 396 |
+
return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided;
|
| 397 |
+
}else{
|
| 398 |
+
return ell_offset_.contiguous() + iteration_contiguous_ * ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements;
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
CUTLASS_HOST_DEVICE
|
| 403 |
+
int get_stride() const {
|
| 404 |
+
if(kAdvanceRank)
|
| 405 |
+
return params_.stride_;
|
| 406 |
+
else
|
| 407 |
+
return 1;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/// Increment and return an instance to self.
|
| 411 |
+
CUTLASS_HOST_DEVICE
|
| 412 |
+
EllPredicatedTileAccessIterator &operator++() {
|
| 413 |
+
|
| 414 |
+
++iteration_vector_;
|
| 415 |
+
if (iteration_vector_ < kAccessesPerVector) {
|
| 416 |
+
return *this;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
iteration_vector_ = 0;
|
| 420 |
+
++iteration_contiguous_;
|
| 421 |
+
|
| 422 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
| 423 |
+
return *this;
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 427 |
+
// ThreadMap::Iteration::kContiguous)
|
| 428 |
+
iteration_contiguous_ = 0;
|
| 429 |
+
++iteration_strided_;
|
| 430 |
+
|
| 431 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 432 |
+
pointer_ += params_.inc_strided_;
|
| 433 |
+
return *this;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 437 |
+
// which means we enter the next tile.
|
| 438 |
+
iteration_strided_ = 0;
|
| 439 |
+
|
| 440 |
+
// advance to next tile
|
| 441 |
+
pointer_ += params_.inc_next_;
|
| 442 |
+
|
| 443 |
+
// now return to start tile - if the iterator is subsequently advanced, this
|
| 444 |
+
// subtraction as well as the subsequent integer addition are both elided by
|
| 445 |
+
// the compiler.
|
| 446 |
+
pointer_ -= params_.inc_advance_;
|
| 447 |
+
|
| 448 |
+
return *this;
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
/// Increment and return an instance to self.
|
| 452 |
+
CUTLASS_HOST_DEVICE
|
| 453 |
+
EllPredicatedTileAccessIterator operator++(int) {
|
| 454 |
+
EllPredicatedTileAccessIterator self(*this);
|
| 455 |
+
operator++();
|
| 456 |
+
return self;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
/// Clears the predicate set efficiently
|
| 460 |
+
CUTLASS_HOST_DEVICE
|
| 461 |
+
void clear_mask(bool enable = true) {
|
| 462 |
+
CUTLASS_PRAGMA_UNROLL
|
| 463 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 464 |
+
predicates_[i] = enable ? 0u : predicates_[i];
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
/// Clears the predicate set efficiently
|
| 470 |
+
CUTLASS_HOST_DEVICE
|
| 471 |
+
void enable_mask() {
|
| 472 |
+
CUTLASS_PRAGMA_UNROLL
|
| 473 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 474 |
+
predicates_[i] = 0xffffffff;
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 480 |
+
CUTLASS_HOST_DEVICE
|
| 481 |
+
void set_mask(Mask const &mask) {
|
| 482 |
+
CUTLASS_PRAGMA_UNROLL
|
| 483 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 484 |
+
predicates_[i] = mask[i];
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
/// Gets the mask
|
| 490 |
+
CUTLASS_HOST_DEVICE
|
| 491 |
+
void get_mask(Mask &mask) {
|
| 492 |
+
CUTLASS_PRAGMA_UNROLL
|
| 493 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 494 |
+
mask[i] = predicates_[i];
|
| 495 |
+
}
|
| 496 |
+
}
|
| 497 |
+
|
| 498 |
+
/// add mask for small tiles in ELL
|
| 499 |
+
CUTLASS_DEVICE
|
| 500 |
+
void ell_add_mask(int blocksize) {
|
| 501 |
+
|
| 502 |
+
Mask mask;
|
| 503 |
+
|
| 504 |
+
CUTLASS_PRAGMA_UNROLL
|
| 505 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 506 |
+
mask[i] = 0u;
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
CUTLASS_PRAGMA_UNROLL
|
| 510 |
+
for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
|
| 511 |
+
|
| 512 |
+
int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 513 |
+
|
| 514 |
+
int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 515 |
+
|
| 516 |
+
int c = access_residual / kAccessesPerVector;
|
| 517 |
+
int v = access_residual % kAccessesPerVector;
|
| 518 |
+
|
| 519 |
+
TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
|
| 520 |
+
s * ThreadMap::Delta::kStrided);
|
| 521 |
+
|
| 522 |
+
TensorCoord coord = ell_offset_ + iteration_coord;
|
| 523 |
+
|
| 524 |
+
bool guard;
|
| 525 |
+
|
| 526 |
+
if (kAdvanceRank == 0) {
|
| 527 |
+
guard = (coord.strided() < blocksize);
|
| 528 |
+
} else {
|
| 529 |
+
guard = (coord.contiguous() < blocksize);
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
|
| 533 |
+
|
| 534 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 535 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 536 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 537 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 538 |
+
|
| 539 |
+
mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
|
| 540 |
+
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
CUTLASS_PRAGMA_UNROLL
|
| 544 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 545 |
+
mask[i] &= predicates_[i];
|
| 546 |
+
}
|
| 547 |
+
set_mask(mask);
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
/// Returns whether access is valid or not
|
| 551 |
+
CUTLASS_HOST_DEVICE
|
| 552 |
+
bool valid() {
|
| 553 |
+
|
| 554 |
+
int pred_idx =
|
| 555 |
+
iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
|
| 556 |
+
|
| 557 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 558 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 559 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 560 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 561 |
+
|
| 562 |
+
bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
|
| 563 |
+
return pred;
|
| 564 |
+
|
| 565 |
+
}
|
| 566 |
+
};
|
| 567 |
+
|
| 568 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 569 |
+
|
| 570 |
+
/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
|
| 571 |
+
///
|
| 572 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 573 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 574 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 575 |
+
/// MaskedTileIteratorConcept
|
| 576 |
+
///
|
| 577 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 578 |
+
typename ThreadMap_, typename AccessType_>
|
| 579 |
+
class EllPredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
|
| 580 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 581 |
+
public:
|
| 582 |
+
static_assert(
|
| 583 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 584 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 585 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 586 |
+
|
| 587 |
+
using Shape = Shape_;
|
| 588 |
+
using Element = Element_;
|
| 589 |
+
using Layout = layout::ColumnMajor;
|
| 590 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 591 |
+
using ThreadMap = ThreadMap_;
|
| 592 |
+
using AccessType = AccessType_;
|
| 593 |
+
|
| 594 |
+
using Index = typename Layout::Index;
|
| 595 |
+
using LongIndex = typename Layout::LongIndex;
|
| 596 |
+
|
| 597 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 598 |
+
using TensorView = TensorView<Element, Layout>;
|
| 599 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 600 |
+
|
| 601 |
+
using Pointer = Element *;
|
| 602 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 603 |
+
|
| 604 |
+
using UnderlyingIterator = EllPredicatedTileAccessIterator<
|
| 605 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 606 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
|
| 607 |
+
|
| 608 |
+
/// Predicate vector stores mask to guard accesses
|
| 609 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 610 |
+
|
| 611 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 612 |
+
|
| 613 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 614 |
+
class Params {
|
| 615 |
+
private:
|
| 616 |
+
friend EllPredicatedTileAccessIterator;
|
| 617 |
+
|
| 618 |
+
/// Parameters object
|
| 619 |
+
typename UnderlyingIterator::Params params_;
|
| 620 |
+
|
| 621 |
+
public:
|
| 622 |
+
|
| 623 |
+
/// Default ctor
|
| 624 |
+
CUTLASS_HOST_DEVICE
|
| 625 |
+
Params() { }
|
| 626 |
+
|
| 627 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 628 |
+
CUTLASS_HOST_DEVICE
|
| 629 |
+
Params(Layout const &layout)
|
| 630 |
+
: params_(layout::PitchLinear(layout.stride(0))){};
|
| 631 |
+
};
|
| 632 |
+
|
| 633 |
+
private:
|
| 634 |
+
//
|
| 635 |
+
// Data members
|
| 636 |
+
//
|
| 637 |
+
|
| 638 |
+
/// Underlying pitch-linear tile iterator
|
| 639 |
+
UnderlyingIterator iterator_;
|
| 640 |
+
|
| 641 |
+
public:
|
| 642 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 643 |
+
/// and thread ID
|
| 644 |
+
CUTLASS_HOST_DEVICE
|
| 645 |
+
EllPredicatedTileAccessIterator(
|
| 646 |
+
///< Precomputed parameters object
|
| 647 |
+
Params const ¶ms,
|
| 648 |
+
///< Pointer to start of tensor
|
| 649 |
+
Pointer pointer,
|
| 650 |
+
///< Extent of tensor
|
| 651 |
+
TensorCoord extent,
|
| 652 |
+
///< ID of each participating thread
|
| 653 |
+
int thread_id,
|
| 654 |
+
///< Initial offset of threadblock
|
| 655 |
+
TensorCoord const &threadblock_offset)
|
| 656 |
+
: iterator_(params.params_, pointer,
|
| 657 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 658 |
+
thread_id,
|
| 659 |
+
layout::PitchLinearCoord(threadblock_offset.row(),
|
| 660 |
+
threadblock_offset.column())) {}
|
| 661 |
+
|
| 662 |
+
/// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
|
| 663 |
+
CUTLASS_HOST_DEVICE
|
| 664 |
+
EllPredicatedTileAccessIterator(
|
| 665 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 666 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 667 |
+
TensorCoord extent, ///< Extent of tensor
|
| 668 |
+
int thread_id ///< ID of each participating thread
|
| 669 |
+
)
|
| 670 |
+
: EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 671 |
+
make_Coord(0, 0)) {}
|
| 672 |
+
|
| 673 |
+
/// Overrides the internal iteration index
|
| 674 |
+
CUTLASS_HOST_DEVICE
|
| 675 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 676 |
+
|
| 677 |
+
/// Adds a pointer offset in units of Element
|
| 678 |
+
CUTLASS_HOST_DEVICE
|
| 679 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 680 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 681 |
+
}
|
| 682 |
+
|
| 683 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 684 |
+
/// tiles
|
| 685 |
+
CUTLASS_HOST_DEVICE
|
| 686 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 687 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
/// Returns a pointer
|
| 691 |
+
CUTLASS_HOST_DEVICE
|
| 692 |
+
AccessType *get() const {
|
| 693 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
CUTLASS_HOST_DEVICE
|
| 697 |
+
int get_k() const {
|
| 698 |
+
return iterator_.get_k();
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
CUTLASS_HOST_DEVICE
|
| 702 |
+
int get_stride() const {
|
| 703 |
+
return iterator_.get_stride();
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
/// Advances to the next tile in memory.
|
| 707 |
+
///
|
| 708 |
+
/// The first time this method is called, predicates are updated, and the
|
| 709 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 710 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 711 |
+
/// pointer.
|
| 712 |
+
CUTLASS_HOST_DEVICE
|
| 713 |
+
EllPredicatedTileAccessIterator &operator++() {
|
| 714 |
+
++iterator_;
|
| 715 |
+
return *this;
|
| 716 |
+
}
|
| 717 |
+
|
| 718 |
+
/// Advances to the next tile in memory.
|
| 719 |
+
///
|
| 720 |
+
/// The first time this method is called, predicates are updated, and the
|
| 721 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 722 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 723 |
+
/// pointer.
|
| 724 |
+
CUTLASS_HOST_DEVICE
|
| 725 |
+
EllPredicatedTileAccessIterator operator++(int) {
|
| 726 |
+
EllPredicatedTileAccessIterator self(*this);
|
| 727 |
+
operator++();
|
| 728 |
+
return self;
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
/// Clears the predicate set efficiently
|
| 732 |
+
CUTLASS_HOST_DEVICE
|
| 733 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 734 |
+
|
| 735 |
+
/// Clears the predicate set efficiently
|
| 736 |
+
CUTLASS_HOST_DEVICE
|
| 737 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 738 |
+
|
| 739 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 740 |
+
CUTLASS_HOST_DEVICE
|
| 741 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 742 |
+
|
| 743 |
+
/// Gets the mask
|
| 744 |
+
CUTLASS_HOST_DEVICE
|
| 745 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 746 |
+
|
| 747 |
+
/// add mask for small tiles in ELL
|
| 748 |
+
CUTLASS_DEVICE
|
| 749 |
+
void ell_add_mask(int blocksize) {
|
| 750 |
+
iterator_.ell_add_mask(blocksize);
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
/// Returns whether access is valid or not
|
| 754 |
+
CUTLASS_HOST_DEVICE
|
| 755 |
+
bool valid() {
|
| 756 |
+
return iterator_.valid();
|
| 757 |
+
}
|
| 758 |
+
};
|
| 759 |
+
|
| 760 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 761 |
+
|
| 762 |
+
/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data.
|
| 763 |
+
///
|
| 764 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 765 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 766 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 767 |
+
/// MaskedTileIteratorConcept
|
| 768 |
+
///
|
| 769 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 770 |
+
typename ThreadMap_, typename AccessType_>
|
| 771 |
+
class EllPredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
|
| 772 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 773 |
+
public:
|
| 774 |
+
static_assert(
|
| 775 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 776 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 777 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 778 |
+
|
| 779 |
+
using Shape = Shape_;
|
| 780 |
+
using Element = Element_;
|
| 781 |
+
using Layout = layout::RowMajor;
|
| 782 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 783 |
+
using ThreadMap = ThreadMap_;
|
| 784 |
+
using AccessType = AccessType_;
|
| 785 |
+
|
| 786 |
+
using Index = typename Layout::Index;
|
| 787 |
+
using LongIndex = typename Layout::LongIndex;
|
| 788 |
+
|
| 789 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 790 |
+
using TensorView = TensorView<Element, Layout>;
|
| 791 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 792 |
+
|
| 793 |
+
using Pointer = Element *;
|
| 794 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 795 |
+
|
| 796 |
+
using UnderlyingIterator = EllPredicatedTileAccessIterator<
|
| 797 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 798 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
|
| 799 |
+
|
| 800 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 801 |
+
|
| 802 |
+
/// Predicate vector stores mask to guard accesses
|
| 803 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 804 |
+
|
| 805 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 806 |
+
class Params {
|
| 807 |
+
private:
|
| 808 |
+
friend EllPredicatedTileAccessIterator;
|
| 809 |
+
|
| 810 |
+
/// Parameters object
|
| 811 |
+
typename UnderlyingIterator::Params params_;
|
| 812 |
+
|
| 813 |
+
public:
|
| 814 |
+
|
| 815 |
+
/// Default ctor
|
| 816 |
+
CUTLASS_HOST_DEVICE
|
| 817 |
+
Params() { }
|
| 818 |
+
|
| 819 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 820 |
+
CUTLASS_HOST_DEVICE
|
| 821 |
+
Params(Layout const &layout)
|
| 822 |
+
: params_(layout::PitchLinear(layout.stride(0))){};
|
| 823 |
+
};
|
| 824 |
+
|
| 825 |
+
private:
|
| 826 |
+
//
|
| 827 |
+
// Data members
|
| 828 |
+
//
|
| 829 |
+
|
| 830 |
+
/// Underlying pitch-linear tile iterator
|
| 831 |
+
UnderlyingIterator iterator_;
|
| 832 |
+
|
| 833 |
+
public:
|
| 834 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 835 |
+
/// and thread ID
|
| 836 |
+
CUTLASS_HOST_DEVICE
|
| 837 |
+
EllPredicatedTileAccessIterator(
|
| 838 |
+
///< Precomputed parameters object
|
| 839 |
+
Params const ¶ms,
|
| 840 |
+
///< Pointer to start of tensor
|
| 841 |
+
Pointer pointer,
|
| 842 |
+
///< Extent of tensor
|
| 843 |
+
TensorCoord extent,
|
| 844 |
+
///< ID of each participating thread
|
| 845 |
+
int thread_id,
|
| 846 |
+
///< Initial offset of threadblock
|
| 847 |
+
TensorCoord const &threadblock_offset)
|
| 848 |
+
: iterator_(params.params_, pointer,
|
| 849 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 850 |
+
thread_id,
|
| 851 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 852 |
+
threadblock_offset.row())) {}
|
| 853 |
+
|
| 854 |
+
/// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
|
| 855 |
+
CUTLASS_HOST_DEVICE
|
| 856 |
+
EllPredicatedTileAccessIterator(
|
| 857 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 858 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 859 |
+
TensorCoord extent, ///< Extent of tensor
|
| 860 |
+
int thread_id ///< ID of each participating thread
|
| 861 |
+
)
|
| 862 |
+
: EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 863 |
+
make_Coord(0, 0)) {}
|
| 864 |
+
|
| 865 |
+
/// Overrides the internal iteration index
|
| 866 |
+
CUTLASS_HOST_DEVICE
|
| 867 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 868 |
+
|
| 869 |
+
/// Adds a pointer offset in units of Element
|
| 870 |
+
CUTLASS_HOST_DEVICE
|
| 871 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 872 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 873 |
+
}
|
| 874 |
+
|
| 875 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 876 |
+
/// tiles
|
| 877 |
+
CUTLASS_HOST_DEVICE
|
| 878 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 879 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
/// Returns a pointer
|
| 883 |
+
CUTLASS_HOST_DEVICE
|
| 884 |
+
AccessType *get() const {
|
| 885 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
CUTLASS_HOST_DEVICE
|
| 889 |
+
int get_k() const {
|
| 890 |
+
return iterator_.get_k();
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
CUTLASS_HOST_DEVICE
|
| 894 |
+
int get_stride() const {
|
| 895 |
+
return iterator_.get_stride();
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
/// Advances to the next tile in memory.
|
| 899 |
+
///
|
| 900 |
+
/// The first time this method is called, predicates are updated, and the
|
| 901 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 902 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 903 |
+
/// pointer.
|
| 904 |
+
CUTLASS_HOST_DEVICE
|
| 905 |
+
EllPredicatedTileAccessIterator &operator++() {
|
| 906 |
+
++iterator_;
|
| 907 |
+
return *this;
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
/// Advances to the next tile in memory.
|
| 911 |
+
///
|
| 912 |
+
/// The first time this method is called, predicates are updated, and the
|
| 913 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 914 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 915 |
+
/// pointer.
|
| 916 |
+
CUTLASS_HOST_DEVICE
|
| 917 |
+
EllPredicatedTileAccessIterator operator++(int) {
|
| 918 |
+
EllPredicatedTileAccessIterator self(*this);
|
| 919 |
+
operator++();
|
| 920 |
+
return self;
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
/// Clears the predicate set efficiently
|
| 924 |
+
CUTLASS_HOST_DEVICE
|
| 925 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 926 |
+
|
| 927 |
+
/// Clears the predicate set efficiently
|
| 928 |
+
CUTLASS_HOST_DEVICE
|
| 929 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 930 |
+
|
| 931 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 932 |
+
CUTLASS_HOST_DEVICE
|
| 933 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 934 |
+
|
| 935 |
+
/// Gets the mask
|
| 936 |
+
CUTLASS_HOST_DEVICE
|
| 937 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 938 |
+
|
| 939 |
+
/// add mask for small tiles in ELL
|
| 940 |
+
CUTLASS_DEVICE
|
| 941 |
+
void ell_add_mask(int blocksize) {
|
| 942 |
+
iterator_.ell_add_mask(blocksize);
|
| 943 |
+
}
|
| 944 |
+
|
| 945 |
+
/// Returns whether access is valid or not
|
| 946 |
+
CUTLASS_HOST_DEVICE
|
| 947 |
+
bool valid() {
|
| 948 |
+
return iterator_.valid();
|
| 949 |
+
}
|
| 950 |
+
};
|
| 951 |
+
|
| 952 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 953 |
+
|
| 954 |
+
/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data.
|
| 955 |
+
/// It is mapped to the congruous layout.
|
| 956 |
+
///
|
| 957 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 958 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 959 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 960 |
+
/// MaskedTileIteratorConcept
|
| 961 |
+
///
|
| 962 |
+
|
| 963 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 964 |
+
typename ThreadMap_, typename AccessType_, int InterleavedK>
|
| 965 |
+
class EllPredicatedTileAccessIterator<Shape_, Element_,
|
| 966 |
+
layout::ColumnMajorInterleaved<InterleavedK>,
|
| 967 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 968 |
+
public:
|
| 969 |
+
static_assert(
|
| 970 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 971 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 972 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 973 |
+
|
| 974 |
+
using Shape = Shape_;
|
| 975 |
+
using Element = Element_;
|
| 976 |
+
static int const kInterleavedK = InterleavedK;
|
| 977 |
+
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
|
| 978 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 979 |
+
using ThreadMap = ThreadMap_;
|
| 980 |
+
using AccessType = AccessType_;
|
| 981 |
+
|
| 982 |
+
using Index = typename Layout::Index;
|
| 983 |
+
using LongIndex = typename Layout::LongIndex;
|
| 984 |
+
|
| 985 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 986 |
+
using TensorView = TensorView<Element, Layout>;
|
| 987 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 988 |
+
|
| 989 |
+
using Pointer = Element *;
|
| 990 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 991 |
+
|
| 992 |
+
using UnderlyingIterator = EllPredicatedTileAccessIterator<
|
| 993 |
+
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
|
| 994 |
+
Shape::kColumn / kInterleavedK>,
|
| 995 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
|
| 996 |
+
AccessType>;
|
| 997 |
+
|
| 998 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 999 |
+
|
| 1000 |
+
/// Predicate vector stores mask to guard accesses
|
| 1001 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1002 |
+
|
| 1003 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1004 |
+
class Params {
|
| 1005 |
+
private:
|
| 1006 |
+
friend EllPredicatedTileAccessIterator;
|
| 1007 |
+
|
| 1008 |
+
/// Parameters object
|
| 1009 |
+
typename UnderlyingIterator::Params params_;
|
| 1010 |
+
|
| 1011 |
+
public:
|
| 1012 |
+
CUTLASS_HOST_DEVICE
|
| 1013 |
+
Params() {}
|
| 1014 |
+
|
| 1015 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1016 |
+
CUTLASS_HOST_DEVICE
|
| 1017 |
+
Params(Layout const &layout)
|
| 1018 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1019 |
+
};
|
| 1020 |
+
|
| 1021 |
+
private:
|
| 1022 |
+
//
|
| 1023 |
+
// Data members
|
| 1024 |
+
//
|
| 1025 |
+
|
| 1026 |
+
/// Underlying pitch-linear tile iterator
|
| 1027 |
+
UnderlyingIterator iterator_;
|
| 1028 |
+
|
| 1029 |
+
public:
|
| 1030 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1031 |
+
/// and thread ID
|
| 1032 |
+
CUTLASS_HOST_DEVICE
|
| 1033 |
+
EllPredicatedTileAccessIterator(
|
| 1034 |
+
/// Precomputed parameters object
|
| 1035 |
+
Params const ¶ms,
|
| 1036 |
+
/// Pointer to start of tensor
|
| 1037 |
+
Pointer pointer,
|
| 1038 |
+
/// Extent of tensor
|
| 1039 |
+
TensorCoord extent,
|
| 1040 |
+
/// ID of each participating thread
|
| 1041 |
+
int thread_id,
|
| 1042 |
+
/// Initial offset of threadblock
|
| 1043 |
+
TensorCoord const &threadblock_offset)
|
| 1044 |
+
: iterator_(params.params_, pointer,
|
| 1045 |
+
layout::PitchLinearCoord(extent.row() * kInterleavedK,
|
| 1046 |
+
extent.column() / kInterleavedK),
|
| 1047 |
+
thread_id,
|
| 1048 |
+
layout::PitchLinearCoord(
|
| 1049 |
+
threadblock_offset.row() * kInterleavedK,
|
| 1050 |
+
threadblock_offset.column() / kInterleavedK)) {}
|
| 1051 |
+
|
| 1052 |
+
/// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
|
| 1053 |
+
CUTLASS_HOST_DEVICE
|
| 1054 |
+
EllPredicatedTileAccessIterator(
|
| 1055 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1056 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1057 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1058 |
+
int thread_id ///< ID of each participating thread
|
| 1059 |
+
)
|
| 1060 |
+
: EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1061 |
+
make_Coord(0, 0)) {}
|
| 1062 |
+
|
| 1063 |
+
/// Overrides the internal iteration index
|
| 1064 |
+
CUTLASS_HOST_DEVICE
|
| 1065 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1066 |
+
|
| 1067 |
+
/// Adds a pointer offset in units of Element
|
| 1068 |
+
CUTLASS_HOST_DEVICE
|
| 1069 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1070 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1074 |
+
/// tiles
|
| 1075 |
+
CUTLASS_HOST_DEVICE
|
| 1076 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1077 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 1078 |
+
}
|
| 1079 |
+
|
| 1080 |
+
/// Returns a pointer
|
| 1081 |
+
CUTLASS_HOST_DEVICE
|
| 1082 |
+
AccessType *get() const {
|
| 1083 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1084 |
+
}
|
| 1085 |
+
|
| 1086 |
+
CUTLASS_HOST_DEVICE
|
| 1087 |
+
int get_k() const {
|
| 1088 |
+
return iterator_.get_k();
|
| 1089 |
+
}
|
| 1090 |
+
|
| 1091 |
+
CUTLASS_HOST_DEVICE
|
| 1092 |
+
int get_stride() const {
|
| 1093 |
+
return iterator_.get_stride();
|
| 1094 |
+
}
|
| 1095 |
+
|
| 1096 |
+
/// Advances to the next tile in memory.
|
| 1097 |
+
///
|
| 1098 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1099 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1100 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1101 |
+
/// pointer.
|
| 1102 |
+
CUTLASS_HOST_DEVICE
|
| 1103 |
+
EllPredicatedTileAccessIterator &operator++() {
|
| 1104 |
+
++iterator_;
|
| 1105 |
+
return *this;
|
| 1106 |
+
}
|
| 1107 |
+
|
| 1108 |
+
/// Advances to the next tile in memory.
|
| 1109 |
+
///
|
| 1110 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1111 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1112 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1113 |
+
/// pointer.
|
| 1114 |
+
CUTLASS_HOST_DEVICE
|
| 1115 |
+
EllPredicatedTileAccessIterator operator++(int) {
|
| 1116 |
+
EllPredicatedTileAccessIterator self(*this);
|
| 1117 |
+
operator++();
|
| 1118 |
+
return self;
|
| 1119 |
+
}
|
| 1120 |
+
|
| 1121 |
+
/// Clears the predicate set efficiently
|
| 1122 |
+
CUTLASS_HOST_DEVICE
|
| 1123 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1124 |
+
|
| 1125 |
+
/// Clears the predicate set efficiently
|
| 1126 |
+
CUTLASS_HOST_DEVICE
|
| 1127 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1128 |
+
|
| 1129 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1130 |
+
CUTLASS_HOST_DEVICE
|
| 1131 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1132 |
+
|
| 1133 |
+
/// Gets the mask
|
| 1134 |
+
CUTLASS_HOST_DEVICE
|
| 1135 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1136 |
+
|
| 1137 |
+
/// add mask for small tiles in ELL
|
| 1138 |
+
CUTLASS_DEVICE
|
| 1139 |
+
void ell_add_mask(int blocksize) {
|
| 1140 |
+
iterator_.ell_add_mask(blocksize);
|
| 1141 |
+
}
|
| 1142 |
+
|
| 1143 |
+
/// Returns whether access is valid or not
|
| 1144 |
+
CUTLASS_HOST_DEVICE
|
| 1145 |
+
bool valid() { return iterator_.valid(); }
|
| 1146 |
+
};
|
| 1147 |
+
|
| 1148 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1149 |
+
|
| 1150 |
+
/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data.
|
| 1151 |
+
/// It is mapped to the congruous layout.
|
| 1152 |
+
///
|
| 1153 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1154 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1155 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1156 |
+
/// MaskedTileIteratorConcept
|
| 1157 |
+
///
|
| 1158 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1159 |
+
typename ThreadMap_, typename AccessType_, int InterleavedK>
|
| 1160 |
+
class EllPredicatedTileAccessIterator<Shape_, Element_,
|
| 1161 |
+
layout::RowMajorInterleaved<InterleavedK>,
|
| 1162 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 1163 |
+
public:
|
| 1164 |
+
static_assert(
|
| 1165 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1166 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1167 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1168 |
+
|
| 1169 |
+
using Shape = Shape_;
|
| 1170 |
+
using Element = Element_;
|
| 1171 |
+
static int const kInterleavedK = InterleavedK;
|
| 1172 |
+
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
|
| 1173 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1174 |
+
using ThreadMap = ThreadMap_;
|
| 1175 |
+
using AccessType = AccessType_;
|
| 1176 |
+
|
| 1177 |
+
using Index = typename Layout::Index;
|
| 1178 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1179 |
+
|
| 1180 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1181 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1182 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1183 |
+
|
| 1184 |
+
using Pointer = Element *;
|
| 1185 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1186 |
+
|
| 1187 |
+
using UnderlyingIterator = EllPredicatedTileAccessIterator<
|
| 1188 |
+
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
|
| 1189 |
+
Shape::kRow / kInterleavedK>,
|
| 1190 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
|
| 1191 |
+
AccessType>;
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 1195 |
+
|
| 1196 |
+
/// Predicate vector stores mask to guard accesses
|
| 1197 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1198 |
+
|
| 1199 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1200 |
+
class Params {
|
| 1201 |
+
private:
|
| 1202 |
+
friend EllPredicatedTileAccessIterator;
|
| 1203 |
+
|
| 1204 |
+
/// Parameters object
|
| 1205 |
+
typename UnderlyingIterator::Params params_;
|
| 1206 |
+
|
| 1207 |
+
public:
|
| 1208 |
+
CUTLASS_HOST_DEVICE
|
| 1209 |
+
Params() {}
|
| 1210 |
+
|
| 1211 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1212 |
+
CUTLASS_HOST_DEVICE
|
| 1213 |
+
Params(Layout const &layout)
|
| 1214 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1215 |
+
};
|
| 1216 |
+
|
| 1217 |
+
private:
|
| 1218 |
+
//
|
| 1219 |
+
// Data members
|
| 1220 |
+
//
|
| 1221 |
+
|
| 1222 |
+
/// Underlying pitch-linear tile iterator
|
| 1223 |
+
UnderlyingIterator iterator_;
|
| 1224 |
+
|
| 1225 |
+
public:
|
| 1226 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1227 |
+
/// and thread ID
|
| 1228 |
+
CUTLASS_HOST_DEVICE
|
| 1229 |
+
EllPredicatedTileAccessIterator(
|
| 1230 |
+
/// Precomputed parameters object
|
| 1231 |
+
Params const ¶ms,
|
| 1232 |
+
/// Pointer to start of tensor
|
| 1233 |
+
Pointer pointer,
|
| 1234 |
+
/// Extent of tensor
|
| 1235 |
+
TensorCoord extent,
|
| 1236 |
+
/// ID of each participating thread
|
| 1237 |
+
int thread_id,
|
| 1238 |
+
/// Initial offset of threadblock
|
| 1239 |
+
TensorCoord const &threadblock_offset)
|
| 1240 |
+
: iterator_(params.params_, pointer,
|
| 1241 |
+
layout::PitchLinearCoord(extent.column() * kInterleavedK,
|
| 1242 |
+
extent.row() / kInterleavedK),
|
| 1243 |
+
thread_id,
|
| 1244 |
+
layout::PitchLinearCoord(
|
| 1245 |
+
threadblock_offset.column() * kInterleavedK,
|
| 1246 |
+
threadblock_offset.row() / kInterleavedK)) {}
|
| 1247 |
+
|
| 1248 |
+
/// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
|
| 1249 |
+
CUTLASS_HOST_DEVICE
|
| 1250 |
+
EllPredicatedTileAccessIterator(
|
| 1251 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1252 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1253 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1254 |
+
int thread_id ///< ID of each participating thread
|
| 1255 |
+
)
|
| 1256 |
+
: EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1257 |
+
make_Coord(0, 0)) {}
|
| 1258 |
+
|
| 1259 |
+
/// Overrides the internal iteration index
|
| 1260 |
+
CUTLASS_HOST_DEVICE
|
| 1261 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1262 |
+
|
| 1263 |
+
/// Adds a pointer offset in units of Element
|
| 1264 |
+
CUTLASS_HOST_DEVICE
|
| 1265 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1266 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1270 |
+
/// tiles
|
| 1271 |
+
CUTLASS_HOST_DEVICE
|
| 1272 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1273 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 1274 |
+
}
|
| 1275 |
+
|
| 1276 |
+
/// Returns a pointer
|
| 1277 |
+
CUTLASS_HOST_DEVICE
|
| 1278 |
+
AccessType *get() const {
|
| 1279 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1280 |
+
}
|
| 1281 |
+
|
| 1282 |
+
CUTLASS_HOST_DEVICE
|
| 1283 |
+
int get_k() const {
|
| 1284 |
+
return iterator_.get_k();
|
| 1285 |
+
}
|
| 1286 |
+
|
| 1287 |
+
CUTLASS_HOST_DEVICE
|
| 1288 |
+
int get_stride() const {
|
| 1289 |
+
return iterator_.get_stride();
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
/// Advances to the next tile in memory.
|
| 1293 |
+
///
|
| 1294 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1295 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1296 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1297 |
+
/// pointer.
|
| 1298 |
+
CUTLASS_HOST_DEVICE
|
| 1299 |
+
EllPredicatedTileAccessIterator &operator++() {
|
| 1300 |
+
++iterator_;
|
| 1301 |
+
return *this;
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
/// Advances to the next tile in memory.
|
| 1305 |
+
///
|
| 1306 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1307 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1308 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1309 |
+
/// pointer.
|
| 1310 |
+
CUTLASS_HOST_DEVICE
|
| 1311 |
+
EllPredicatedTileAccessIterator operator++(int) {
|
| 1312 |
+
EllPredicatedTileAccessIterator self(*this);
|
| 1313 |
+
operator++();
|
| 1314 |
+
return self;
|
| 1315 |
+
}
|
| 1316 |
+
|
| 1317 |
+
/// Clears the predicate set efficiently
|
| 1318 |
+
CUTLASS_HOST_DEVICE
|
| 1319 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1320 |
+
|
| 1321 |
+
/// Clears the predicate set efficiently
|
| 1322 |
+
CUTLASS_HOST_DEVICE
|
| 1323 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1324 |
+
|
| 1325 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1326 |
+
CUTLASS_HOST_DEVICE
|
| 1327 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1328 |
+
|
| 1329 |
+
/// Gets the mask
|
| 1330 |
+
CUTLASS_HOST_DEVICE
|
| 1331 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1332 |
+
|
| 1333 |
+
/// add mask for small tiles in ELL
|
| 1334 |
+
CUTLASS_DEVICE
|
| 1335 |
+
void ell_add_mask(int blocksize) {
|
| 1336 |
+
iterator_.ell_add_mask(blocksize);
|
| 1337 |
+
}
|
| 1338 |
+
|
| 1339 |
+
/// Returns whether access is valid or not
|
| 1340 |
+
CUTLASS_HOST_DEVICE
|
| 1341 |
+
bool valid() { return iterator_.valid(); }
|
| 1342 |
+
};
|
| 1343 |
+
|
| 1344 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1345 |
+
|
| 1346 |
+
} // namespace threadblock
|
| 1347 |
+
} // namespace transform
|
| 1348 |
+
} // namespace cutlass
|
| 1349 |
+
|
| 1350 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h
ADDED
|
@@ -0,0 +1,1315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/arch/memory.h"
|
| 38 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h"
|
| 41 |
+
#include "cutlass/transform/threadblock/ell_iterator.h"
|
| 42 |
+
|
| 43 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace transform {
|
| 47 |
+
namespace threadblock {
|
| 48 |
+
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
/// EllPredicatedTileIterator
|
| 52 |
+
///
|
| 53 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 54 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 55 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 56 |
+
/// MaskedTileIteratorConcept
|
| 57 |
+
///
|
| 58 |
+
/// Regular tile iterator using a precomputed control structure to minimize register liveness
|
| 59 |
+
/// and integer arithmetic.
|
| 60 |
+
///
|
| 61 |
+
/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
|
| 62 |
+
///
|
| 63 |
+
/// Base pointer and tensor extents may be specified at the time the iterator is constructed.
|
| 64 |
+
/// Subsequently, they are assumed to be immutable.
|
| 65 |
+
///
|
| 66 |
+
/// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
|
| 67 |
+
/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
|
| 68 |
+
///
|
| 69 |
+
/// Visitation order is intended to first visit a "residual" tile that may be partially full in
|
| 70 |
+
/// both the advance dimension and the steady-state dimension. This is assumed to be the last
|
| 71 |
+
/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
|
| 72 |
+
/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
|
| 73 |
+
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
| 74 |
+
/// live register state and pointer arithmetic instructions.
|
| 75 |
+
///
|
| 76 |
+
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
| 77 |
+
/// outside any looping structure to minimize integer arithmetic.
|
| 78 |
+
///
|
| 79 |
+
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
| 80 |
+
/// the iterator.
|
| 81 |
+
///
|
| 82 |
+
///
|
| 83 |
+
/// Example:
|
| 84 |
+
///
|
| 85 |
+
/// An efficient pipeline structure may be constructed as follows:
|
| 86 |
+
///
|
| 87 |
+
// template <typename Iterator>
|
| 88 |
+
// __global__ void kernel(
|
| 89 |
+
// typename Iterator::Params params,
|
| 90 |
+
// typename Iterator::Element *ptr,
|
| 91 |
+
// TensorCoord extent) {
|
| 92 |
+
//
|
| 93 |
+
// typename Iterator::Fragment fragment;
|
| 94 |
+
//
|
| 95 |
+
// TensorCoord threadblock_offset(0, 0);
|
| 96 |
+
//
|
| 97 |
+
// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
|
| 98 |
+
//
|
| 99 |
+
//
|
| 100 |
+
// fragment = *iter; // load "residue" tile first
|
| 101 |
+
// ++iter; // advance to first "steady state" tile and update internal masks
|
| 102 |
+
//
|
| 103 |
+
//
|
| 104 |
+
// #pragma unroll
|
| 105 |
+
// for (int i = Remaining - 1; i >= 0; --i) {
|
| 106 |
+
//
|
| 107 |
+
// f(fragment);
|
| 108 |
+
//
|
| 109 |
+
// if (!i) {
|
| 110 |
+
// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
|
| 111 |
+
// }
|
| 112 |
+
//
|
| 113 |
+
// fragment = *iter; // load tile during "steady state" phase
|
| 114 |
+
// ++iter; // advance to next tile - lightweight due to steady-state masks
|
| 115 |
+
// }
|
| 116 |
+
// }
|
| 117 |
+
//
|
| 118 |
+
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
|
| 119 |
+
//
|
| 120 |
+
// using Iterator = transform::threadblock::EllPredicatedTileIterator;
|
| 121 |
+
//
|
| 122 |
+
// typename Iterator::Params params(view.layout());
|
| 123 |
+
//
|
| 124 |
+
// kernel<Iterator>(params, view.data());
|
| 125 |
+
// }
|
| 126 |
+
///
|
| 127 |
+
///
|
| 128 |
+
template <
|
| 129 |
+
typename Shape,
|
| 130 |
+
typename Element,
|
| 131 |
+
typename Layout,
|
| 132 |
+
int AdvanceRank,
|
| 133 |
+
typename ThreadMap,
|
| 134 |
+
int AccessSize = ThreadMap::kElementsPerAccess
|
| 135 |
+
>
|
| 136 |
+
class EllPredicatedTileIterator;
|
| 137 |
+
|
| 138 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 139 |
+
|
| 140 |
+
/// Specialization of EllPredicatedTileIterator for pitch-linear data.
|
| 141 |
+
///
|
| 142 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 143 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 144 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 145 |
+
/// MaskedTileIteratorConcept
|
| 146 |
+
///
|
| 147 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 148 |
+
typename ThreadMap_, int AccessSize>
|
| 149 |
+
class EllPredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
| 150 |
+
ThreadMap_, AccessSize> {
|
| 151 |
+
public:
|
| 152 |
+
static_assert(
|
| 153 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 154 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 155 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 156 |
+
|
| 157 |
+
using Shape = Shape_;
|
| 158 |
+
using Element = Element_;
|
| 159 |
+
using Layout = layout::PitchLinear;
|
| 160 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 161 |
+
using ThreadMap = ThreadMap_;
|
| 162 |
+
|
| 163 |
+
using Index = typename Layout::Index;
|
| 164 |
+
using LongIndex = typename Layout::LongIndex;
|
| 165 |
+
|
| 166 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 167 |
+
using TensorView = TensorView<Element, Layout>;
|
| 168 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 169 |
+
|
| 170 |
+
using Pointer = Element *;
|
| 171 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 172 |
+
|
| 173 |
+
/// Type used for internal memory accesses
|
| 174 |
+
using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 175 |
+
|
| 176 |
+
/// Underlying iterator to compute the addresses
|
| 177 |
+
using TileAccessIterator =
|
| 178 |
+
EllPredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
|
| 179 |
+
ThreadMap, AccessType>;
|
| 180 |
+
|
| 181 |
+
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 182 |
+
|
| 183 |
+
/// Fragment object to be loaded or stored
|
| 184 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 185 |
+
ThreadMap::kElementsPerAccess>;
|
| 186 |
+
|
| 187 |
+
/// Predicate vector stores mask to guard accesses
|
| 188 |
+
using Mask = typename TileAccessIterator::Mask;
|
| 189 |
+
|
| 190 |
+
/// Iterator for ELL storage
|
| 191 |
+
using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
|
| 192 |
+
|
| 193 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 194 |
+
class Params {
|
| 195 |
+
public:
|
| 196 |
+
friend EllPredicatedTileIterator;
|
| 197 |
+
|
| 198 |
+
private:
|
| 199 |
+
/// Parameters object
|
| 200 |
+
typename TileAccessIterator::Params params_;
|
| 201 |
+
|
| 202 |
+
public:
|
| 203 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 204 |
+
CUTLASS_HOST_DEVICE
|
| 205 |
+
Params(Layout const &layout) : params_(layout) { }
|
| 206 |
+
|
| 207 |
+
CUTLASS_HOST_DEVICE
|
| 208 |
+
Params() { }
|
| 209 |
+
};
|
| 210 |
+
|
| 211 |
+
private:
|
| 212 |
+
/// Internal pointer type permits fast address arithmetic
|
| 213 |
+
using BytePointer = char *;
|
| 214 |
+
|
| 215 |
+
private:
|
| 216 |
+
//
|
| 217 |
+
// Data members
|
| 218 |
+
//
|
| 219 |
+
|
| 220 |
+
/// Data member to the tile access iterator
|
| 221 |
+
TileAccessIterator address_iterator_;
|
| 222 |
+
|
| 223 |
+
public:
|
| 224 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 225 |
+
/// and thread ID
|
| 226 |
+
CUTLASS_HOST_DEVICE
|
| 227 |
+
EllPredicatedTileIterator(
|
| 228 |
+
/// Precomputed parameters object
|
| 229 |
+
Params const ¶ms,
|
| 230 |
+
/// Pointer to start of tensor
|
| 231 |
+
Pointer pointer,
|
| 232 |
+
/// Extent of tensor
|
| 233 |
+
TensorCoord extent,
|
| 234 |
+
/// ID of each participating thread
|
| 235 |
+
int thread_id,
|
| 236 |
+
/// Initial offset of threadblock
|
| 237 |
+
TensorCoord const &threadblock_offset)
|
| 238 |
+
: address_iterator_(params.params_, pointer, extent, thread_id,
|
| 239 |
+
threadblock_offset) {}
|
| 240 |
+
|
| 241 |
+
/// Construct a EllPredicatedTileIterator with zero threadblock offset
|
| 242 |
+
CUTLASS_HOST_DEVICE
|
| 243 |
+
EllPredicatedTileIterator(
|
| 244 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 245 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 246 |
+
TensorCoord extent, ///< Extent of tensor
|
| 247 |
+
int thread_id ///< ID of each participating thread
|
| 248 |
+
)
|
| 249 |
+
: EllPredicatedTileIterator(params, pointer, extent, thread_id,
|
| 250 |
+
make_Coord(0, 0)) {}
|
| 251 |
+
|
| 252 |
+
/// Adds a pointer offset in units of Element
|
| 253 |
+
CUTLASS_HOST_DEVICE
|
| 254 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 255 |
+
address_iterator_.add_pointer_offset(pointer_offset);
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
/// Advances to the next tile in memory.
|
| 259 |
+
///
|
| 260 |
+
/// The first time this method is called, predicates are updated, and the
|
| 261 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 262 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 263 |
+
/// pointer.
|
| 264 |
+
CUTLASS_HOST_DEVICE
|
| 265 |
+
EllPredicatedTileIterator &operator++() {
|
| 266 |
+
if (kAdvanceRank)
|
| 267 |
+
address_iterator_.add_tile_offset({0, 1});
|
| 268 |
+
else
|
| 269 |
+
address_iterator_.add_tile_offset({1, 0});
|
| 270 |
+
|
| 271 |
+
return *this;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
/// Advances to the next tile in memory.
|
| 275 |
+
///
|
| 276 |
+
/// The first time this method is called, predicates are updated, and the
|
| 277 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 278 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 279 |
+
/// pointer.
|
| 280 |
+
CUTLASS_HOST_DEVICE
|
| 281 |
+
EllPredicatedTileIterator operator++(int) {
|
| 282 |
+
EllPredicatedTileIterator self(*this);
|
| 283 |
+
operator++();
|
| 284 |
+
return self;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
/// Returns a stride
|
| 288 |
+
CUTLASS_HOST_DEVICE
|
| 289 |
+
int get_stride() const { return address_iterator_.get_stride(); }
|
| 290 |
+
|
| 291 |
+
/// Clears the predicate set efficiently
|
| 292 |
+
CUTLASS_HOST_DEVICE
|
| 293 |
+
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
| 294 |
+
|
| 295 |
+
/// Clears the predicate set efficiently
|
| 296 |
+
CUTLASS_HOST_DEVICE
|
| 297 |
+
void enable_mask() { address_iterator_.enable_mask(); }
|
| 298 |
+
|
| 299 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 300 |
+
CUTLASS_HOST_DEVICE
|
| 301 |
+
void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
|
| 302 |
+
|
| 303 |
+
/// Gets the mask
|
| 304 |
+
CUTLASS_HOST_DEVICE
|
| 305 |
+
void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
|
| 306 |
+
|
| 307 |
+
/// add mask for small tiles in ELL
|
| 308 |
+
CUTLASS_HOST_DEVICE
|
| 309 |
+
void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); }
|
| 310 |
+
|
| 311 |
+
CUTLASS_DEVICE
|
| 312 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 313 |
+
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
CUTLASS_DEVICE
|
| 317 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 318 |
+
|
| 319 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 320 |
+
|
| 321 |
+
CUTLASS_PRAGMA_UNROLL
|
| 322 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 323 |
+
CUTLASS_PRAGMA_UNROLL
|
| 324 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 325 |
+
|
| 326 |
+
CUTLASS_PRAGMA_UNROLL
|
| 327 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 328 |
+
|
| 329 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 330 |
+
|
| 331 |
+
address_iterator_.set_iteration_index(idx);
|
| 332 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
| 333 |
+
|
| 334 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 335 |
+
|
| 336 |
+
cutlass::arch::global_load<AccessType,
|
| 337 |
+
sizeof(AccessType)
|
| 338 |
+
>(
|
| 339 |
+
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 340 |
+
|
| 341 |
+
++address_iterator_;
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
/// Loads a fragment from memory
|
| 348 |
+
CUTLASS_DEVICE
|
| 349 |
+
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
|
| 350 |
+
|
| 351 |
+
CUTLASS_DEVICE
|
| 352 |
+
void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) {
|
| 353 |
+
|
| 354 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 355 |
+
|
| 356 |
+
CUTLASS_PRAGMA_UNROLL
|
| 357 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 358 |
+
CUTLASS_PRAGMA_UNROLL
|
| 359 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 360 |
+
CUTLASS_PRAGMA_UNROLL
|
| 361 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 362 |
+
|
| 363 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 364 |
+
address_iterator_.set_iteration_index(idx);
|
| 365 |
+
LongIndex ell_offset = 0;
|
| 366 |
+
|
| 367 |
+
int k_offset = address_iterator_.get_k();
|
| 368 |
+
ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element);
|
| 369 |
+
|
| 370 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + ell_offset;
|
| 371 |
+
|
| 372 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 373 |
+
|
| 374 |
+
bool is_valid = address_iterator_.valid();
|
| 375 |
+
is_valid = is_valid && (ell_offset >= 0);
|
| 376 |
+
|
| 377 |
+
cutlass::arch::global_load<AccessType,
|
| 378 |
+
sizeof(AccessType)
|
| 379 |
+
>(
|
| 380 |
+
frag_ptr[idx], access_ptr, is_valid);
|
| 381 |
+
|
| 382 |
+
++address_iterator_;
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
CUTLASS_DEVICE
|
| 389 |
+
void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) {
|
| 390 |
+
|
| 391 |
+
LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element);
|
| 392 |
+
|
| 393 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 394 |
+
|
| 395 |
+
CUTLASS_PRAGMA_UNROLL
|
| 396 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 397 |
+
CUTLASS_PRAGMA_UNROLL
|
| 398 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 399 |
+
|
| 400 |
+
CUTLASS_PRAGMA_UNROLL
|
| 401 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 402 |
+
|
| 403 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 404 |
+
|
| 405 |
+
address_iterator_.set_iteration_index(idx);
|
| 406 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + ell_offset;
|
| 407 |
+
|
| 408 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 409 |
+
|
| 410 |
+
bool is_valid = address_iterator_.valid();
|
| 411 |
+
is_valid = is_valid && (ell_offset >= 0);
|
| 412 |
+
|
| 413 |
+
cutlass::arch::global_load<AccessType,
|
| 414 |
+
sizeof(AccessType)
|
| 415 |
+
>(
|
| 416 |
+
frag_ptr[idx], access_ptr, is_valid);
|
| 417 |
+
|
| 418 |
+
++address_iterator_;
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
}
|
| 423 |
+
/// Store a fragment to memory
|
| 424 |
+
CUTLASS_DEVICE
|
| 425 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 426 |
+
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
/// Store a fragment to memory
|
| 430 |
+
CUTLASS_DEVICE
|
| 431 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 432 |
+
address_iterator_.set_iteration_index(0);
|
| 433 |
+
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
| 434 |
+
|
| 435 |
+
CUTLASS_PRAGMA_UNROLL
|
| 436 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 437 |
+
CUTLASS_PRAGMA_UNROLL
|
| 438 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 439 |
+
CUTLASS_PRAGMA_UNROLL
|
| 440 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 441 |
+
|
| 442 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 443 |
+
|
| 444 |
+
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
| 445 |
+
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
| 446 |
+
|
| 447 |
+
if (address_iterator_.valid()) {
|
| 448 |
+
*access_ptr = frag_ptr[idx];
|
| 449 |
+
}
|
| 450 |
+
++address_iterator_;
|
| 451 |
+
}
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
/// Store a fragment to memory
|
| 457 |
+
CUTLASS_DEVICE
|
| 458 |
+
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
|
| 459 |
+
};
|
| 460 |
+
|
| 461 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 462 |
+
|
| 463 |
+
/// Specialization of EllPredicatedTileIterator for pitch-linear data.
|
| 464 |
+
///
|
| 465 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 466 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 467 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 468 |
+
/// MaskedTileIteratorConcept
|
| 469 |
+
///
|
| 470 |
+
template <
|
| 471 |
+
typename Shape_,
|
| 472 |
+
typename Element_,
|
| 473 |
+
int AdvanceRank,
|
| 474 |
+
typename ThreadMap_,
|
| 475 |
+
int AccessSize
|
| 476 |
+
>
|
| 477 |
+
class EllPredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, AccessSize> {
|
| 478 |
+
public:
|
| 479 |
+
|
| 480 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 481 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 482 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 483 |
+
|
| 484 |
+
using Shape = Shape_;
|
| 485 |
+
using Element = Element_;
|
| 486 |
+
using Layout = layout::ColumnMajor;
|
| 487 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 488 |
+
using ThreadMap = ThreadMap_;
|
| 489 |
+
|
| 490 |
+
using Index = typename Layout::Index;
|
| 491 |
+
using LongIndex = typename Layout::LongIndex;
|
| 492 |
+
|
| 493 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 494 |
+
using TensorView = TensorView<Element, Layout>;
|
| 495 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 496 |
+
|
| 497 |
+
using Pointer = Element *;
|
| 498 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 499 |
+
|
| 500 |
+
using UnderlyingIterator = EllPredicatedTileIterator<
|
| 501 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 502 |
+
Element,
|
| 503 |
+
layout::PitchLinear,
|
| 504 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 505 |
+
ThreadMap,
|
| 506 |
+
AccessSize
|
| 507 |
+
>;
|
| 508 |
+
|
| 509 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 510 |
+
|
| 511 |
+
/// Fragment object to be loaded or stored
|
| 512 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 513 |
+
|
| 514 |
+
/// Predicate vector stores mask to guard accesses
|
| 515 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 516 |
+
|
| 517 |
+
/// Iterator for ELL storage
|
| 518 |
+
using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
|
| 519 |
+
|
| 520 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 521 |
+
class Params {
|
| 522 |
+
private:
|
| 523 |
+
|
| 524 |
+
friend EllPredicatedTileIterator;
|
| 525 |
+
|
| 526 |
+
/// Parameters object
|
| 527 |
+
typename UnderlyingIterator::Params params_;
|
| 528 |
+
|
| 529 |
+
public:
|
| 530 |
+
|
| 531 |
+
CUTLASS_HOST_DEVICE
|
| 532 |
+
Params() { }
|
| 533 |
+
|
| 534 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 535 |
+
CUTLASS_HOST_DEVICE
|
| 536 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
|
| 537 |
+
|
| 538 |
+
}
|
| 539 |
+
};
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
private:
|
| 543 |
+
|
| 544 |
+
//
|
| 545 |
+
// Data members
|
| 546 |
+
//
|
| 547 |
+
|
| 548 |
+
/// Underlying pitch-linear tile iterator
|
| 549 |
+
UnderlyingIterator iterator_;
|
| 550 |
+
|
| 551 |
+
public:
|
| 552 |
+
|
| 553 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 554 |
+
CUTLASS_HOST_DEVICE
|
| 555 |
+
EllPredicatedTileIterator(
|
| 556 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 557 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 558 |
+
TensorCoord extent, ///< Extent of tensor
|
| 559 |
+
int thread_id, ///< ID of each participating thread
|
| 560 |
+
TensorCoord const &threadblock_offset ///< Initial offset of threadblock
|
| 561 |
+
):
|
| 562 |
+
iterator_(
|
| 563 |
+
params.params_,
|
| 564 |
+
pointer,
|
| 565 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 566 |
+
thread_id,
|
| 567 |
+
layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
|
| 568 |
+
) { }
|
| 569 |
+
|
| 570 |
+
/// Construct a EllPredicatedTileIterator with zero threadblock offset
|
| 571 |
+
CUTLASS_HOST_DEVICE
|
| 572 |
+
EllPredicatedTileIterator(
|
| 573 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 574 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 575 |
+
TensorCoord extent, ///< Extent of tensor
|
| 576 |
+
int thread_id ///< ID of each participating thread
|
| 577 |
+
): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 578 |
+
|
| 579 |
+
/// Adds a pointer offset in units of Element
|
| 580 |
+
CUTLASS_HOST_DEVICE
|
| 581 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 582 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
/// Advances to the next tile in memory.
|
| 586 |
+
///
|
| 587 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 588 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 589 |
+
/// are lightweight and must only update the internal pointer.
|
| 590 |
+
CUTLASS_HOST_DEVICE
|
| 591 |
+
EllPredicatedTileIterator &operator++() {
|
| 592 |
+
++iterator_;
|
| 593 |
+
return *this;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
/// Advances to the next tile in memory.
|
| 597 |
+
///
|
| 598 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 599 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 600 |
+
/// are lightweight and must only update the internal pointer.
|
| 601 |
+
CUTLASS_HOST_DEVICE
|
| 602 |
+
EllPredicatedTileIterator operator++(int) {
|
| 603 |
+
EllPredicatedTileIterator self(*this);
|
| 604 |
+
operator++();
|
| 605 |
+
return self;
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
/// Returns a stride
|
| 609 |
+
CUTLASS_HOST_DEVICE
|
| 610 |
+
int get_stride() const { return iterator_.get_stride(); }
|
| 611 |
+
|
| 612 |
+
/// Clears the predicate set efficiently
|
| 613 |
+
CUTLASS_HOST_DEVICE
|
| 614 |
+
void clear_mask(bool enable = true) {
|
| 615 |
+
iterator_.clear_mask(enable);
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
/// Clears the predicate set efficiently
|
| 619 |
+
CUTLASS_HOST_DEVICE
|
| 620 |
+
void enable_mask() {
|
| 621 |
+
iterator_.enable_mask();
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 625 |
+
CUTLASS_HOST_DEVICE
|
| 626 |
+
void set_mask(Mask const &mask) {
|
| 627 |
+
iterator_.set_mask(mask);
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
/// Gets the mask
|
| 631 |
+
CUTLASS_HOST_DEVICE
|
| 632 |
+
void get_mask(Mask &mask) {
|
| 633 |
+
iterator_.get_mask(mask);
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
/// add mask for small tiles in ELL
|
| 637 |
+
CUTLASS_HOST_DEVICE
|
| 638 |
+
void ell_add_mask(int blocksize) {
|
| 639 |
+
iterator_.ell_add_mask(blocksize);
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
/// Loads a fragment from memory
|
| 643 |
+
CUTLASS_DEVICE
|
| 644 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 645 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
/// Loads a fragment from memory
|
| 649 |
+
CUTLASS_DEVICE
|
| 650 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 651 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
/// Loads a fragment from memory
|
| 655 |
+
CUTLASS_DEVICE
|
| 656 |
+
void load(Fragment &frag) {
|
| 657 |
+
load_with_pointer_offset(frag, 0);
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
CUTLASS_DEVICE
|
| 661 |
+
void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
|
| 662 |
+
iterator_.load_with_ell_index(frag, ell_iter);
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
CUTLASS_DEVICE
|
| 666 |
+
void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
|
| 667 |
+
iterator_.load_with_ell_index_fast(frag, ell_iter);
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
/// Store a fragment to memory
|
| 671 |
+
CUTLASS_DEVICE
|
| 672 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 673 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
/// Store a fragment to memory
|
| 677 |
+
CUTLASS_DEVICE
|
| 678 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 679 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
/// Store a fragment to memory
|
| 683 |
+
CUTLASS_DEVICE
|
| 684 |
+
void store(Fragment const &frag) {
|
| 685 |
+
store_with_pointer_offset(frag, 0);
|
| 686 |
+
}
|
| 687 |
+
};
|
| 688 |
+
|
| 689 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 690 |
+
|
| 691 |
+
/// Specialization of EllPredicatedTileIterator for pitch-linear data.
|
| 692 |
+
///
|
| 693 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 694 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 695 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 696 |
+
/// MaskedTileIteratorConcept
|
| 697 |
+
///
|
| 698 |
+
template <
|
| 699 |
+
typename Shape_,
|
| 700 |
+
typename Element_,
|
| 701 |
+
int AdvanceRank,
|
| 702 |
+
typename ThreadMap_,
|
| 703 |
+
int AccessSize
|
| 704 |
+
>
|
| 705 |
+
class EllPredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, AccessSize> {
|
| 706 |
+
public:
|
| 707 |
+
|
| 708 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 709 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 710 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 711 |
+
|
| 712 |
+
using Shape = Shape_;
|
| 713 |
+
using Element = Element_;
|
| 714 |
+
using Layout = layout::RowMajor;
|
| 715 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 716 |
+
using ThreadMap = ThreadMap_;
|
| 717 |
+
|
| 718 |
+
using Index = typename Layout::Index;
|
| 719 |
+
using LongIndex = typename Layout::LongIndex;
|
| 720 |
+
|
| 721 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 722 |
+
using TensorView = TensorView<Element, Layout>;
|
| 723 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 724 |
+
|
| 725 |
+
using Pointer = Element *;
|
| 726 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 727 |
+
|
| 728 |
+
using UnderlyingIterator = EllPredicatedTileIterator<
|
| 729 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 730 |
+
Element,
|
| 731 |
+
layout::PitchLinear,
|
| 732 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 733 |
+
ThreadMap,
|
| 734 |
+
AccessSize
|
| 735 |
+
>;
|
| 736 |
+
|
| 737 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 738 |
+
|
| 739 |
+
/// Fragment object to be loaded or stored
|
| 740 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 741 |
+
|
| 742 |
+
/// Predicate vector stores mask to guard accesses
|
| 743 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 744 |
+
|
| 745 |
+
/// Iterator for ELL storage
|
| 746 |
+
using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
|
| 747 |
+
|
| 748 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 749 |
+
class Params {
|
| 750 |
+
private:
|
| 751 |
+
|
| 752 |
+
friend EllPredicatedTileIterator;
|
| 753 |
+
|
| 754 |
+
/// Parameters object
|
| 755 |
+
typename UnderlyingIterator::Params params_;
|
| 756 |
+
|
| 757 |
+
public:
|
| 758 |
+
|
| 759 |
+
CUTLASS_HOST_DEVICE
|
| 760 |
+
Params() { }
|
| 761 |
+
|
| 762 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 763 |
+
CUTLASS_HOST_DEVICE
|
| 764 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
|
| 765 |
+
|
| 766 |
+
};
|
| 767 |
+
};
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
private:
|
| 771 |
+
|
| 772 |
+
//
|
| 773 |
+
// Data members
|
| 774 |
+
//
|
| 775 |
+
|
| 776 |
+
/// Underlying pitch-linear tile iterator
|
| 777 |
+
UnderlyingIterator iterator_;
|
| 778 |
+
|
| 779 |
+
public:
|
| 780 |
+
|
| 781 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 782 |
+
CUTLASS_HOST_DEVICE
|
| 783 |
+
EllPredicatedTileIterator(
|
| 784 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 785 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 786 |
+
TensorCoord extent, ///< Extent of tensor
|
| 787 |
+
int thread_id, ///< ID of each participating thread
|
| 788 |
+
TensorCoord const &threadblock_offset ///< Initial offset of threadblock
|
| 789 |
+
):
|
| 790 |
+
iterator_(
|
| 791 |
+
params.params_,
|
| 792 |
+
pointer,
|
| 793 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 794 |
+
thread_id,
|
| 795 |
+
layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
|
| 796 |
+
) { }
|
| 797 |
+
|
| 798 |
+
/// Construct a EllPredicatedTileIterator with zero threadblock offset
|
| 799 |
+
CUTLASS_HOST_DEVICE
|
| 800 |
+
EllPredicatedTileIterator(
|
| 801 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 802 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 803 |
+
TensorCoord extent, ///< Extent of tensor
|
| 804 |
+
int thread_id ///< ID of each participating thread
|
| 805 |
+
): EllPredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 806 |
+
|
| 807 |
+
/// Adds a pointer offset in units of Element
|
| 808 |
+
CUTLASS_HOST_DEVICE
|
| 809 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 810 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
/// Advances to the next tile in memory.
|
| 814 |
+
///
|
| 815 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 816 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 817 |
+
/// are lightweight and must only update the internal pointer.
|
| 818 |
+
CUTLASS_HOST_DEVICE
|
| 819 |
+
EllPredicatedTileIterator &operator++() {
|
| 820 |
+
++iterator_;
|
| 821 |
+
return *this;
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
/// Advances to the next tile in memory.
|
| 825 |
+
///
|
| 826 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 827 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 828 |
+
/// are lightweight and must only update the internal pointer.
|
| 829 |
+
CUTLASS_HOST_DEVICE
|
| 830 |
+
EllPredicatedTileIterator operator++(int) {
|
| 831 |
+
EllPredicatedTileIterator self(*this);
|
| 832 |
+
operator++();
|
| 833 |
+
return self;
|
| 834 |
+
}
|
| 835 |
+
|
| 836 |
+
/// Returns a stride
|
| 837 |
+
CUTLASS_HOST_DEVICE
|
| 838 |
+
int get_stride() const { return iterator_.get_stride(); }
|
| 839 |
+
|
| 840 |
+
/// Clears the predicate set efficiently
|
| 841 |
+
CUTLASS_HOST_DEVICE
|
| 842 |
+
void clear_mask(bool enable = true) {
|
| 843 |
+
iterator_.clear_mask(enable);
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
/// Clears the predicate set efficiently
|
| 847 |
+
CUTLASS_HOST_DEVICE
|
| 848 |
+
void enable_mask() {
|
| 849 |
+
iterator_.enable_mask();
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 853 |
+
CUTLASS_HOST_DEVICE
|
| 854 |
+
void set_mask(Mask const &mask) {
|
| 855 |
+
iterator_.set_mask(mask);
|
| 856 |
+
}
|
| 857 |
+
|
| 858 |
+
/// Gets the mask
|
| 859 |
+
CUTLASS_HOST_DEVICE
|
| 860 |
+
void get_mask(Mask &mask) {
|
| 861 |
+
iterator_.get_mask(mask);
|
| 862 |
+
}
|
| 863 |
+
|
| 864 |
+
/// add mask for small tiles in ELL
|
| 865 |
+
CUTLASS_HOST_DEVICE
|
| 866 |
+
void ell_add_mask(int blocksize) {
|
| 867 |
+
iterator_.ell_add_mask(blocksize);
|
| 868 |
+
}
|
| 869 |
+
|
| 870 |
+
/// Loads a fragment from memory
|
| 871 |
+
CUTLASS_DEVICE
|
| 872 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 873 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 874 |
+
}
|
| 875 |
+
|
| 876 |
+
/// Loads a fragment from memory
|
| 877 |
+
CUTLASS_DEVICE
|
| 878 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 879 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
/// Loads a fragment from memory
|
| 883 |
+
CUTLASS_DEVICE
|
| 884 |
+
void load(Fragment &frag) {
|
| 885 |
+
load_with_pointer_offset(frag, 0);
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
CUTLASS_DEVICE
|
| 889 |
+
void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
|
| 890 |
+
iterator_.load_with_ell_index(frag, ell_iter);
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
CUTLASS_DEVICE
|
| 894 |
+
void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
|
| 895 |
+
iterator_.load_with_ell_index_fast(frag, ell_iter);
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
/// Store a fragment to memory
|
| 899 |
+
CUTLASS_DEVICE
|
| 900 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 901 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
/// Store a fragment to memory
|
| 905 |
+
CUTLASS_DEVICE
|
| 906 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 907 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 908 |
+
}
|
| 909 |
+
|
| 910 |
+
/// Store a fragment to memory
|
| 911 |
+
CUTLASS_DEVICE
|
| 912 |
+
void store(Fragment const &frag) {
|
| 913 |
+
store_with_pointer_offset(frag, 0);
|
| 914 |
+
}
|
| 915 |
+
};
|
| 916 |
+
|
| 917 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 918 |
+
|
| 919 |
+
/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped
|
| 920 |
+
/// to the congruous layout.
|
| 921 |
+
///
|
| 922 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 923 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 924 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 925 |
+
/// MaskedTileIteratorConcept
|
| 926 |
+
///
|
| 927 |
+
|
| 928 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 929 |
+
typename ThreadMap_, int AccessSize, int InterleavedK>
|
| 930 |
+
class EllPredicatedTileIterator<Shape_, Element_,
|
| 931 |
+
layout::ColumnMajorInterleaved<InterleavedK>,
|
| 932 |
+
AdvanceRank, ThreadMap_, AccessSize> {
|
| 933 |
+
public:
|
| 934 |
+
static_assert(
|
| 935 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 936 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 937 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 938 |
+
|
| 939 |
+
using Shape = Shape_;
|
| 940 |
+
using Element = Element_;
|
| 941 |
+
static int const kInterleavedK = InterleavedK;
|
| 942 |
+
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
|
| 943 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 944 |
+
using ThreadMap = ThreadMap_;
|
| 945 |
+
|
| 946 |
+
using Index = typename Layout::Index;
|
| 947 |
+
using LongIndex = typename Layout::LongIndex;
|
| 948 |
+
|
| 949 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 950 |
+
using TensorView = TensorView<Element, Layout>;
|
| 951 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 952 |
+
|
| 953 |
+
using Pointer = Element *;
|
| 954 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 955 |
+
|
| 956 |
+
using UnderlyingIterator = EllPredicatedTileIterator<
|
| 957 |
+
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
|
| 958 |
+
Shape::kColumn / kInterleavedK>,
|
| 959 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 963 |
+
|
| 964 |
+
/// Fragment object to be loaded or stored
|
| 965 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 966 |
+
ThreadMap::kElementsPerAccess>;
|
| 967 |
+
|
| 968 |
+
/// Predicate vector stores mask to guard accesses
|
| 969 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 970 |
+
|
| 971 |
+
/// Iterator for ELL storage
|
| 972 |
+
using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
|
| 973 |
+
|
| 974 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 975 |
+
class Params {
|
| 976 |
+
private:
|
| 977 |
+
friend EllPredicatedTileIterator;
|
| 978 |
+
|
| 979 |
+
/// Parameters object
|
| 980 |
+
typename UnderlyingIterator::Params params_;
|
| 981 |
+
|
| 982 |
+
public:
|
| 983 |
+
CUTLASS_HOST_DEVICE
|
| 984 |
+
Params() {}
|
| 985 |
+
|
| 986 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 987 |
+
CUTLASS_HOST_DEVICE
|
| 988 |
+
Params(Layout const &layout)
|
| 989 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 990 |
+
};
|
| 991 |
+
|
| 992 |
+
private:
|
| 993 |
+
//
|
| 994 |
+
// Data members
|
| 995 |
+
//
|
| 996 |
+
|
| 997 |
+
/// Underlying pitch-linear tile iterator
|
| 998 |
+
UnderlyingIterator iterator_;
|
| 999 |
+
|
| 1000 |
+
public:
|
| 1001 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1002 |
+
/// and thread ID
|
| 1003 |
+
CUTLASS_HOST_DEVICE
|
| 1004 |
+
EllPredicatedTileIterator(
|
| 1005 |
+
/// Precomputed parameters object
|
| 1006 |
+
Params const ¶ms,
|
| 1007 |
+
/// Pointer to start of tensor
|
| 1008 |
+
Pointer pointer,
|
| 1009 |
+
/// Extent of tensor
|
| 1010 |
+
TensorCoord extent,
|
| 1011 |
+
/// ID of each participating thread
|
| 1012 |
+
int thread_id,
|
| 1013 |
+
/// Initial offset of threadblock
|
| 1014 |
+
TensorCoord const &threadblock_offset)
|
| 1015 |
+
: iterator_(params.params_, pointer,
|
| 1016 |
+
layout::PitchLinearCoord(extent.row() * kInterleavedK,
|
| 1017 |
+
extent.column() / kInterleavedK),
|
| 1018 |
+
thread_id,
|
| 1019 |
+
layout::PitchLinearCoord(
|
| 1020 |
+
threadblock_offset.row() * kInterleavedK,
|
| 1021 |
+
threadblock_offset.column() / kInterleavedK)) {}
|
| 1022 |
+
|
| 1023 |
+
/// Construct a EllPredicatedTileIterator with zero threadblock offset
|
| 1024 |
+
CUTLASS_HOST_DEVICE
|
| 1025 |
+
EllPredicatedTileIterator(
|
| 1026 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1027 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1028 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1029 |
+
int thread_id ///< ID of each participating thread
|
| 1030 |
+
)
|
| 1031 |
+
: EllPredicatedTileIterator(params, pointer, extent, thread_id,
|
| 1032 |
+
make_Coord(0, 0)) {}
|
| 1033 |
+
|
| 1034 |
+
/// Adds a pointer offset in units of Element
|
| 1035 |
+
CUTLASS_HOST_DEVICE
|
| 1036 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1037 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
/// Advances to the next tile in memory.
|
| 1041 |
+
///
|
| 1042 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1043 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1044 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1045 |
+
/// pointer.
|
| 1046 |
+
CUTLASS_HOST_DEVICE
|
| 1047 |
+
EllPredicatedTileIterator &operator++() {
|
| 1048 |
+
++iterator_;
|
| 1049 |
+
return *this;
|
| 1050 |
+
}
|
| 1051 |
+
|
| 1052 |
+
/// Advances to the next tile in memory.
|
| 1053 |
+
///
|
| 1054 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1055 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1056 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1057 |
+
/// pointer.
|
| 1058 |
+
CUTLASS_HOST_DEVICE
|
| 1059 |
+
EllPredicatedTileIterator operator++(int) {
|
| 1060 |
+
EllPredicatedTileIterator self(*this);
|
| 1061 |
+
operator++();
|
| 1062 |
+
return self;
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
/// Returns a stride
|
| 1066 |
+
CUTLASS_HOST_DEVICE
|
| 1067 |
+
int get_stride() const { return iterator_.get_stride(); }
|
| 1068 |
+
|
| 1069 |
+
/// Clears the predicate set efficiently
|
| 1070 |
+
CUTLASS_HOST_DEVICE
|
| 1071 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1072 |
+
|
| 1073 |
+
/// Clears the predicate set efficiently
|
| 1074 |
+
CUTLASS_HOST_DEVICE
|
| 1075 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1076 |
+
|
| 1077 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1078 |
+
CUTLASS_HOST_DEVICE
|
| 1079 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1080 |
+
|
| 1081 |
+
/// Gets the mask
|
| 1082 |
+
CUTLASS_HOST_DEVICE
|
| 1083 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1084 |
+
|
| 1085 |
+
/// add mask for small tiles in ELL
|
| 1086 |
+
CUTLASS_HOST_DEVICE
|
| 1087 |
+
void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); }
|
| 1088 |
+
|
| 1089 |
+
/// Loads a fragment from memory
|
| 1090 |
+
CUTLASS_DEVICE
|
| 1091 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1092 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1093 |
+
}
|
| 1094 |
+
|
| 1095 |
+
CUTLASS_DEVICE
|
| 1096 |
+
void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) {
|
| 1097 |
+
iterator_.load_with_ell_index(frag, ell_iter);
|
| 1098 |
+
}
|
| 1099 |
+
|
| 1100 |
+
CUTLASS_DEVICE
|
| 1101 |
+
void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) {
|
| 1102 |
+
iterator_.load_with_ell_index_fast(frag, ell_iter);
|
| 1103 |
+
}
|
| 1104 |
+
|
| 1105 |
+
/// Loads a fragment from memory
|
| 1106 |
+
CUTLASS_DEVICE
|
| 1107 |
+
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
| 1108 |
+
|
| 1109 |
+
/// Store a fragment to memory
|
| 1110 |
+
CUTLASS_DEVICE
|
| 1111 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1112 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1113 |
+
}
|
| 1114 |
+
|
| 1115 |
+
/// Store a fragment to memory
|
| 1116 |
+
CUTLASS_DEVICE
|
| 1117 |
+
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
| 1118 |
+
};
|
| 1119 |
+
|
| 1120 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1121 |
+
|
| 1122 |
+
/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is
|
| 1123 |
+
/// mapped to the congruous layout.
|
| 1124 |
+
///
|
| 1125 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1126 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1127 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1128 |
+
/// MaskedTileIteratorConcept
|
| 1129 |
+
///
|
| 1130 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1131 |
+
typename ThreadMap_, int AccessSize, int InterleavedK>
|
| 1132 |
+
class EllPredicatedTileIterator<Shape_, Element_,
|
| 1133 |
+
layout::RowMajorInterleaved<InterleavedK>,
|
| 1134 |
+
AdvanceRank, ThreadMap_, AccessSize> {
|
| 1135 |
+
public:
|
| 1136 |
+
static_assert(
|
| 1137 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1138 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1139 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1140 |
+
|
| 1141 |
+
using Shape = Shape_;
|
| 1142 |
+
using Element = Element_;
|
| 1143 |
+
static int const kInterleavedK = InterleavedK;
|
| 1144 |
+
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
|
| 1145 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1146 |
+
using ThreadMap = ThreadMap_;
|
| 1147 |
+
|
| 1148 |
+
using Index = typename Layout::Index;
|
| 1149 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1150 |
+
|
| 1151 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1152 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1153 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1154 |
+
|
| 1155 |
+
using Pointer = Element *;
|
| 1156 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1157 |
+
|
| 1158 |
+
using UnderlyingIterator = EllPredicatedTileIterator<
|
| 1159 |
+
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
|
| 1160 |
+
Shape::kRow / kInterleavedK>,
|
| 1161 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1165 |
+
|
| 1166 |
+
/// Fragment object to be loaded or stored
|
| 1167 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 1168 |
+
ThreadMap::kElementsPerAccess>;
|
| 1169 |
+
|
| 1170 |
+
/// Predicate vector stores mask to guard accesses
|
| 1171 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1172 |
+
|
| 1173 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1174 |
+
class Params {
|
| 1175 |
+
private:
|
| 1176 |
+
friend EllPredicatedTileIterator;
|
| 1177 |
+
|
| 1178 |
+
/// Parameters object
|
| 1179 |
+
typename UnderlyingIterator::Params params_;
|
| 1180 |
+
|
| 1181 |
+
public:
|
| 1182 |
+
CUTLASS_HOST_DEVICE
|
| 1183 |
+
Params() {}
|
| 1184 |
+
|
| 1185 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1186 |
+
CUTLASS_HOST_DEVICE
|
| 1187 |
+
Params(Layout const &layout)
|
| 1188 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1189 |
+
};
|
| 1190 |
+
|
| 1191 |
+
private:
|
| 1192 |
+
//
|
| 1193 |
+
// Data members
|
| 1194 |
+
//
|
| 1195 |
+
|
| 1196 |
+
/// Underlying pitch-linear tile iterator
|
| 1197 |
+
UnderlyingIterator iterator_;
|
| 1198 |
+
|
| 1199 |
+
public:
|
| 1200 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1201 |
+
/// and thread ID
|
| 1202 |
+
CUTLASS_HOST_DEVICE
|
| 1203 |
+
EllPredicatedTileIterator(
|
| 1204 |
+
/// Precomputed parameters object
|
| 1205 |
+
Params const ¶ms,
|
| 1206 |
+
/// Pointer to start of tensor
|
| 1207 |
+
Pointer pointer,
|
| 1208 |
+
/// Extent of tensor
|
| 1209 |
+
TensorCoord extent,
|
| 1210 |
+
/// ID of each participating thread
|
| 1211 |
+
int thread_id,
|
| 1212 |
+
/// Initial offset of threadblock
|
| 1213 |
+
TensorCoord const &threadblock_offset)
|
| 1214 |
+
: iterator_(params.params_, pointer,
|
| 1215 |
+
layout::PitchLinearCoord(extent.column() * kInterleavedK,
|
| 1216 |
+
extent.row() / kInterleavedK),
|
| 1217 |
+
thread_id,
|
| 1218 |
+
layout::PitchLinearCoord(
|
| 1219 |
+
threadblock_offset.column() * kInterleavedK,
|
| 1220 |
+
threadblock_offset.row() / kInterleavedK)) {}
|
| 1221 |
+
|
| 1222 |
+
/// Construct a EllPredicatedTileIterator with zero threadblock offset
|
| 1223 |
+
CUTLASS_HOST_DEVICE
|
| 1224 |
+
EllPredicatedTileIterator(
|
| 1225 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1226 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1227 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1228 |
+
int thread_id ///< ID of each participating thread
|
| 1229 |
+
)
|
| 1230 |
+
: EllPredicatedTileIterator(params, pointer, extent, thread_id,
|
| 1231 |
+
make_Coord(0, 0)) {}
|
| 1232 |
+
|
| 1233 |
+
/// Adds a pointer offset in units of Element
|
| 1234 |
+
CUTLASS_HOST_DEVICE
|
| 1235 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1236 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1237 |
+
}
|
| 1238 |
+
|
| 1239 |
+
/// Advances to the next tile in memory.
|
| 1240 |
+
///
|
| 1241 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1242 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1243 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1244 |
+
/// pointer.
|
| 1245 |
+
CUTLASS_HOST_DEVICE
|
| 1246 |
+
EllPredicatedTileIterator &operator++() {
|
| 1247 |
+
++iterator_;
|
| 1248 |
+
return *this;
|
| 1249 |
+
}
|
| 1250 |
+
|
| 1251 |
+
/// Advances to the next tile in memory.
|
| 1252 |
+
///
|
| 1253 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1254 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1255 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1256 |
+
/// pointer.
|
| 1257 |
+
CUTLASS_HOST_DEVICE
|
| 1258 |
+
EllPredicatedTileIterator operator++(int) {
|
| 1259 |
+
EllPredicatedTileIterator self(*this);
|
| 1260 |
+
operator++();
|
| 1261 |
+
return self;
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
/// Returns a stride
|
| 1265 |
+
CUTLASS_HOST_DEVICE
|
| 1266 |
+
int get_stride() const { return iterator_.get_stride(); }
|
| 1267 |
+
|
| 1268 |
+
/// Clears the predicate set efficiently
|
| 1269 |
+
CUTLASS_HOST_DEVICE
|
| 1270 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1271 |
+
|
| 1272 |
+
/// Clears the predicate set efficiently
|
| 1273 |
+
CUTLASS_HOST_DEVICE
|
| 1274 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1275 |
+
|
| 1276 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1277 |
+
CUTLASS_HOST_DEVICE
|
| 1278 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1279 |
+
|
| 1280 |
+
/// Gets the mask
|
| 1281 |
+
CUTLASS_HOST_DEVICE
|
| 1282 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1283 |
+
|
| 1284 |
+
/// add mask for small tiles in ELL
|
| 1285 |
+
CUTLASS_HOST_DEVICE
|
| 1286 |
+
void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); }
|
| 1287 |
+
|
| 1288 |
+
/// Loads a fragment from memory
|
| 1289 |
+
CUTLASS_DEVICE
|
| 1290 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1291 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1292 |
+
}
|
| 1293 |
+
|
| 1294 |
+
/// Loads a fragment from memory
|
| 1295 |
+
CUTLASS_DEVICE
|
| 1296 |
+
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
| 1297 |
+
|
| 1298 |
+
/// Store a fragment to memory
|
| 1299 |
+
CUTLASS_DEVICE
|
| 1300 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1301 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
/// Store a fragment to memory
|
| 1305 |
+
CUTLASS_DEVICE
|
| 1306 |
+
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
| 1307 |
+
};
|
| 1308 |
+
|
| 1309 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1310 |
+
|
| 1311 |
+
} // namespace threadblock
|
| 1312 |
+
} // namespace transform
|
| 1313 |
+
} // namespace cutlass
|
| 1314 |
+
|
| 1315 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Templates calculating the address and predicates to the load of scale and bias vectors.
|
| 34 |
+
|
| 35 |
+
This iterator uses masks to guard out-of-bounds accesses.
|
| 36 |
+
|
| 37 |
+
It can be used to load the gamma and beta vectors of layernorm which is loop variant.
|
| 38 |
+
|
| 39 |
+
A precomputed "Params" object minimizes the amount of state that must be
|
| 40 |
+
stored in registers, and integer addition is used to advance the pointer
|
| 41 |
+
through memory.
|
| 42 |
+
*/
|
| 43 |
+
|
| 44 |
+
#pragma once
|
| 45 |
+
|
| 46 |
+
#include "cutlass/array.h"
|
| 47 |
+
#include "cutlass/coord.h"
|
| 48 |
+
#include "cutlass/cutlass.h"
|
| 49 |
+
#include "cutlass/layout/matrix.h"
|
| 50 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 51 |
+
#include "cutlass/matrix_shape.h"
|
| 52 |
+
#include "cutlass/predicate_vector.h"
|
| 53 |
+
#include "cutlass/tensor_ref.h"
|
| 54 |
+
#include "cutlass/tensor_view.h"
|
| 55 |
+
#include "cutlass/conv/threadblock/conv2d_params.h"
|
| 56 |
+
|
| 57 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
namespace cutlass {
|
| 60 |
+
namespace transform {
|
| 61 |
+
namespace threadblock {
|
| 62 |
+
|
| 63 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
/// PredicatedScaleBiasVectorAccessIterator
|
| 66 |
+
///
|
| 67 |
+
template <typename ThreadblockShape,
|
| 68 |
+
typename Element,
|
| 69 |
+
typename Layout>
|
| 70 |
+
class PredicatedScaleBiasVectorAccessIterator;
|
| 71 |
+
|
| 72 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 73 |
+
|
| 74 |
+
/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data.
|
| 75 |
+
///
|
| 76 |
+
template <typename ThreadblockShape_, typename Element_>
|
| 77 |
+
class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
|
| 78 |
+
Element_,
|
| 79 |
+
layout::PitchLinear> {
|
| 80 |
+
public:
|
| 81 |
+
|
| 82 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 83 |
+
using Element = Element_;
|
| 84 |
+
using Layout = layout::PitchLinear;
|
| 85 |
+
|
| 86 |
+
using Index = typename Layout::Index;
|
| 87 |
+
using LongIndex = typename Layout::LongIndex;
|
| 88 |
+
|
| 89 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 90 |
+
using TensorView = TensorView<Element, Layout>;
|
| 91 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 92 |
+
|
| 93 |
+
using ConstPointer = const Element *;
|
| 94 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 95 |
+
|
| 96 |
+
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
| 97 |
+
static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess;
|
| 98 |
+
|
| 99 |
+
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
| 100 |
+
|
| 101 |
+
private:
|
| 102 |
+
/// Internal pointer type permits fast address arithmetic
|
| 103 |
+
using BytePointer = char *;
|
| 104 |
+
|
| 105 |
+
private:
|
| 106 |
+
//
|
| 107 |
+
// Data members
|
| 108 |
+
//
|
| 109 |
+
|
| 110 |
+
/// Internal pointer to first access of tile
|
| 111 |
+
BytePointer pointer_;
|
| 112 |
+
|
| 113 |
+
TensorCoord thread_offset_;
|
| 114 |
+
|
| 115 |
+
int problem_size_k_;
|
| 116 |
+
|
| 117 |
+
/// Used for out-of-order visitation
|
| 118 |
+
bool is_residue_tile_;
|
| 119 |
+
|
| 120 |
+
bool guard_;
|
| 121 |
+
|
| 122 |
+
TensorCoord::Index residue_size_;
|
| 123 |
+
|
| 124 |
+
public:
|
| 125 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 126 |
+
/// and thread ID
|
| 127 |
+
CUTLASS_HOST_DEVICE
|
| 128 |
+
PredicatedScaleBiasVectorAccessIterator(
|
| 129 |
+
/// Extent of tensor
|
| 130 |
+
int problem_size_k,
|
| 131 |
+
/// Pointer to the start of the scale vector
|
| 132 |
+
ConstPointer scale_pointer,
|
| 133 |
+
/// Pointer to the start of the bias vector
|
| 134 |
+
ConstPointer bias_pointer,
|
| 135 |
+
/// ID of each participating thread
|
| 136 |
+
int thread_id,
|
| 137 |
+
/// Initial offset of threadblock
|
| 138 |
+
TensorCoord const &threadblock_offset) {
|
| 139 |
+
pointer_ = (thread_id < kThreads)
|
| 140 |
+
? reinterpret_cast<BytePointer>(
|
| 141 |
+
const_cast<NonConstPointer>(scale_pointer))
|
| 142 |
+
: reinterpret_cast<BytePointer>(
|
| 143 |
+
const_cast<NonConstPointer>(bias_pointer));
|
| 144 |
+
|
| 145 |
+
// Per-thread offset in logical coordinates of tensor
|
| 146 |
+
int thread_base = (thread_id < kThreads) ? 0 : kThreads;
|
| 147 |
+
|
| 148 |
+
problem_size_k_ = problem_size_k;
|
| 149 |
+
|
| 150 |
+
is_residue_tile_ = true;
|
| 151 |
+
|
| 152 |
+
residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous;
|
| 153 |
+
|
| 154 |
+
if (residue_size_ == 0) {
|
| 155 |
+
residue_size_ = ThreadblockShape::kContiguous;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_;
|
| 159 |
+
|
| 160 |
+
thread_offset_ =
|
| 161 |
+
threadblock_offset +
|
| 162 |
+
TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0);
|
| 163 |
+
|
| 164 |
+
set_iteration_index(0);
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 168 |
+
CUTLASS_HOST_DEVICE
|
| 169 |
+
PredicatedScaleBiasVectorAccessIterator(
|
| 170 |
+
/// Extent of tensor
|
| 171 |
+
int problem_size_k,
|
| 172 |
+
/// Pointer to start of scale vector
|
| 173 |
+
ConstPointer scale_pointer,
|
| 174 |
+
/// Pointer to start of scale vector
|
| 175 |
+
ConstPointer bias_pointer,
|
| 176 |
+
///< ID of each participating thread
|
| 177 |
+
int thread_id)
|
| 178 |
+
: PredicatedScaleBiasVectorAccessIterator(problem_size_k,
|
| 179 |
+
scale_pointer, bias_pointer,
|
| 180 |
+
thread_id, make_Coord(0, 0)) {}
|
| 181 |
+
|
| 182 |
+
/// Overrides the internal iteration index
|
| 183 |
+
CUTLASS_HOST_DEVICE
|
| 184 |
+
void set_iteration_index(int index) {}
|
| 185 |
+
|
| 186 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles
|
| 187 |
+
CUTLASS_DEVICE
|
| 188 |
+
void add_tile_offset(
|
| 189 |
+
TensorCoord const &tile_offset) {
|
| 190 |
+
|
| 191 |
+
guard_ = threadIdx.x < kThreads * 2;
|
| 192 |
+
|
| 193 |
+
TensorCoord offset = is_residue_tile_ ?
|
| 194 |
+
TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0)
|
| 195 |
+
: TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0);
|
| 196 |
+
|
| 197 |
+
thread_offset_ =
|
| 198 |
+
thread_offset_ +
|
| 199 |
+
offset;
|
| 200 |
+
|
| 201 |
+
is_residue_tile_ = false;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
/// Returns a pointer
|
| 205 |
+
CUTLASS_HOST_DEVICE
|
| 206 |
+
AccessType *get() const {
|
| 207 |
+
|
| 208 |
+
return reinterpret_cast<AccessType *>(
|
| 209 |
+
pointer_ +
|
| 210 |
+
(thread_offset_.contiguous() * sizeof_bits<Element>::value / 8));
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Increment and return an instance to self.
|
| 214 |
+
CUTLASS_HOST_DEVICE
|
| 215 |
+
PredicatedScaleBiasVectorAccessIterator &operator++() {
|
| 216 |
+
return *this;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
/// Increment and return an instance to self.
|
| 220 |
+
CUTLASS_DEVICE
|
| 221 |
+
PredicatedScaleBiasVectorAccessIterator operator++(int) {
|
| 222 |
+
PredicatedScaleBiasVectorAccessIterator self(*this);
|
| 223 |
+
operator++();
|
| 224 |
+
return self;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
/// Clears the predicate set efficiently
|
| 228 |
+
CUTLASS_HOST_DEVICE
|
| 229 |
+
void clear_mask(bool enable = true) {
|
| 230 |
+
guard_ &= (!enable);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
/// Returns whether access is valid or not
|
| 234 |
+
CUTLASS_HOST_DEVICE
|
| 235 |
+
bool valid() {
|
| 236 |
+
return guard_;
|
| 237 |
+
}
|
| 238 |
+
};
|
| 239 |
+
|
| 240 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 241 |
+
|
| 242 |
+
/// Specialization of PredicatedTileAccessIterator for row-major data.
|
| 243 |
+
///
|
| 244 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 245 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 246 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 247 |
+
/// MaskedTileIteratorConcept
|
| 248 |
+
///
|
| 249 |
+
template <typename ThreadblockShape_,
|
| 250 |
+
typename Element_>
|
| 251 |
+
class PredicatedScaleBiasVectorAccessIterator<ThreadblockShape_,
|
| 252 |
+
Element_,
|
| 253 |
+
layout::RowMajor> {
|
| 254 |
+
public:
|
| 255 |
+
|
| 256 |
+
using ThreadblockShape = ThreadblockShape_;
|
| 257 |
+
using Element = Element_;
|
| 258 |
+
using Layout = layout::RowMajor;
|
| 259 |
+
|
| 260 |
+
using Index = typename Layout::Index;
|
| 261 |
+
using LongIndex = typename Layout::LongIndex;
|
| 262 |
+
|
| 263 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 264 |
+
using TensorView = TensorView<Element, Layout>;
|
| 265 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 266 |
+
|
| 267 |
+
using ConstPointer = const Element *;
|
| 268 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 269 |
+
|
| 270 |
+
using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator<
|
| 271 |
+
layout::PitchLinearShape<ThreadblockShape::kColumn, ThreadblockShape::kRow>,
|
| 272 |
+
Element,
|
| 273 |
+
layout::PitchLinear>;
|
| 274 |
+
|
| 275 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 276 |
+
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
|
| 277 |
+
|
| 278 |
+
private:
|
| 279 |
+
//
|
| 280 |
+
// Data members
|
| 281 |
+
//
|
| 282 |
+
|
| 283 |
+
/// Underlying pitch-linear tile iterator
|
| 284 |
+
UnderlyingIterator iterator_;
|
| 285 |
+
|
| 286 |
+
public:
|
| 287 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 288 |
+
/// and thread ID
|
| 289 |
+
CUTLASS_HOST_DEVICE
|
| 290 |
+
PredicatedScaleBiasVectorAccessIterator(
|
| 291 |
+
///< Extent of tensor
|
| 292 |
+
int problem_size_k,
|
| 293 |
+
///< Pointer to the start of the scale vector
|
| 294 |
+
ConstPointer scale_pointer,
|
| 295 |
+
///< Pointer to the start of the bias vector
|
| 296 |
+
ConstPointer bias_pointer,
|
| 297 |
+
///< ID of each participating thread
|
| 298 |
+
int thread_id,
|
| 299 |
+
///< Initial offset of threadblock
|
| 300 |
+
TensorCoord const &threadblock_offset)
|
| 301 |
+
: iterator_(problem_size_k, scale_pointer, bias_pointer,
|
| 302 |
+
thread_id,
|
| 303 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 304 |
+
threadblock_offset.row())) {}
|
| 305 |
+
|
| 306 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 307 |
+
CUTLASS_HOST_DEVICE
|
| 308 |
+
PredicatedScaleBiasVectorAccessIterator(
|
| 309 |
+
int problem_size_k, ///< Extent of tensor
|
| 310 |
+
ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
|
| 311 |
+
ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
|
| 312 |
+
int thread_id ///< ID of each participating thread
|
| 313 |
+
)
|
| 314 |
+
: PredicatedScaleBiasVectorAccessIterator(problem_size_k,
|
| 315 |
+
scale_pointer, bias_pointer,
|
| 316 |
+
thread_id, make_Coord(0, 0)) {}
|
| 317 |
+
|
| 318 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 319 |
+
/// threadblock tiles
|
| 320 |
+
CUTLASS_HOST_DEVICE
|
| 321 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 322 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
/// Returns a pointer
|
| 326 |
+
CUTLASS_HOST_DEVICE
|
| 327 |
+
AccessType *get() const {
|
| 328 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/// Advances to the next tile in memory.
|
| 332 |
+
///
|
| 333 |
+
/// The first time this method is called, predicates are updated, and the
|
| 334 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 335 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 336 |
+
/// pointer.
|
| 337 |
+
CUTLASS_HOST_DEVICE
|
| 338 |
+
PredicatedScaleBiasVectorAccessIterator &operator++() {
|
| 339 |
+
++iterator_;
|
| 340 |
+
return *this;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
/// Advances to the next tile in memory.
|
| 344 |
+
///
|
| 345 |
+
/// The first time this method is called, predicates are updated, and the
|
| 346 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 347 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 348 |
+
/// pointer.
|
| 349 |
+
CUTLASS_HOST_DEVICE
|
| 350 |
+
PredicatedScaleBiasVectorAccessIterator operator++(int) {
|
| 351 |
+
PredicatedScaleBiasVectorAccessIterator self(*this);
|
| 352 |
+
operator++();
|
| 353 |
+
return self;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
/// Clears the predicate set efficiently
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
void clear_mask(bool enable = true) {
|
| 359 |
+
iterator_.clear_mask(enable);
|
| 360 |
+
}
|
| 361 |
+
|
| 362 |
+
/// Returns whether access is valid or not
|
| 363 |
+
CUTLASS_HOST_DEVICE
|
| 364 |
+
bool valid() {
|
| 365 |
+
return iterator_.valid();
|
| 366 |
+
}
|
| 367 |
+
};
|
| 368 |
+
|
| 369 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 370 |
+
|
| 371 |
+
} // namespace threadblock
|
| 372 |
+
} // namespace transform
|
| 373 |
+
} // namespace cutlass
|
| 374 |
+
|
| 375 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Templates calculating the address and predicates to the load of scale and bias vectors.
|
| 34 |
+
|
| 35 |
+
This iterator uses masks to guard out-of-bounds accesses.
|
| 36 |
+
|
| 37 |
+
This can be used to load var and mean vectors in layernorm which is loop invariant.
|
| 38 |
+
|
| 39 |
+
A precomputed "Params" object minimizes the amount of state that must be
|
| 40 |
+
stored in registers, and integer addition is used to advance the pointer
|
| 41 |
+
through memory.
|
| 42 |
+
*/
|
| 43 |
+
|
| 44 |
+
#pragma once
|
| 45 |
+
|
| 46 |
+
#include "cutlass/array.h"
|
| 47 |
+
#include "cutlass/coord.h"
|
| 48 |
+
#include "cutlass/cutlass.h"
|
| 49 |
+
#include "cutlass/layout/matrix.h"
|
| 50 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 51 |
+
#include "cutlass/matrix_shape.h"
|
| 52 |
+
#include "cutlass/predicate_vector.h"
|
| 53 |
+
#include "cutlass/tensor_ref.h"
|
| 54 |
+
#include "cutlass/tensor_view.h"
|
| 55 |
+
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
namespace cutlass {
|
| 59 |
+
namespace transform {
|
| 60 |
+
namespace threadblock {
|
| 61 |
+
|
| 62 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 63 |
+
|
| 64 |
+
/// PredicatedScaleBiasVectorIterator
|
| 65 |
+
///
|
| 66 |
+
template <typename WarpShape,
|
| 67 |
+
typename Element,
|
| 68 |
+
typename Layout>
|
| 69 |
+
class PredicatedScaleBiasVectorIterator;
|
| 70 |
+
|
| 71 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 72 |
+
|
| 73 |
+
/// Specialization of PredicatedTileIterator for wgrad pitch-linear data.
|
| 74 |
+
///
|
| 75 |
+
template <typename WarpShape_, typename Element_>
|
| 76 |
+
class PredicatedScaleBiasVectorIterator<WarpShape_,
|
| 77 |
+
Element_,
|
| 78 |
+
layout::PitchLinear> {
|
| 79 |
+
public:
|
| 80 |
+
|
| 81 |
+
using WarpShape = WarpShape_;
|
| 82 |
+
using Element = Element_;
|
| 83 |
+
using Layout = layout::PitchLinear;
|
| 84 |
+
|
| 85 |
+
using Index = typename Layout::Index;
|
| 86 |
+
using LongIndex = typename Layout::LongIndex;
|
| 87 |
+
|
| 88 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 89 |
+
using TensorView = TensorView<Element, Layout>;
|
| 90 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 91 |
+
|
| 92 |
+
using ConstPointer = const Element *;
|
| 93 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 94 |
+
|
| 95 |
+
static int const kElementsPerAccess = 1;
|
| 96 |
+
|
| 97 |
+
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
| 98 |
+
|
| 99 |
+
static int const kIterations = WarpShape::kContiguous / 8;
|
| 100 |
+
|
| 101 |
+
/// Fragment object to be loaded or stored
|
| 102 |
+
using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>;
|
| 103 |
+
|
| 104 |
+
private:
|
| 105 |
+
//
|
| 106 |
+
// Data members
|
| 107 |
+
//
|
| 108 |
+
|
| 109 |
+
/// Internal pointer to first access of tile
|
| 110 |
+
ConstPointer scale_pointer_;
|
| 111 |
+
ConstPointer bias_pointer_;
|
| 112 |
+
|
| 113 |
+
/// Size of tensor
|
| 114 |
+
int problem_size_;
|
| 115 |
+
|
| 116 |
+
int32_t thread_offset_;
|
| 117 |
+
|
| 118 |
+
public:
|
| 119 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 120 |
+
/// and thread ID
|
| 121 |
+
CUTLASS_HOST_DEVICE
|
| 122 |
+
PredicatedScaleBiasVectorIterator(
|
| 123 |
+
/// Extent of tensor
|
| 124 |
+
int problem_size,
|
| 125 |
+
/// Pointer to the start of the scale vector
|
| 126 |
+
ConstPointer scale_pointer,
|
| 127 |
+
/// Pointer to the start of the bias vector
|
| 128 |
+
ConstPointer bias_pointer,
|
| 129 |
+
/// ID of each participating thread
|
| 130 |
+
int thread_id,
|
| 131 |
+
/// Initial offset of threadblock
|
| 132 |
+
TensorCoord const &threadblock_offset)
|
| 133 |
+
: problem_size_(problem_size),
|
| 134 |
+
scale_pointer_(scale_pointer),
|
| 135 |
+
bias_pointer_(bias_pointer) {
|
| 136 |
+
|
| 137 |
+
thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 141 |
+
CUTLASS_HOST_DEVICE
|
| 142 |
+
PredicatedScaleBiasVectorIterator(
|
| 143 |
+
/// Extent of tensor
|
| 144 |
+
int problem_size,
|
| 145 |
+
/// Pointer to start of scale vector
|
| 146 |
+
ConstPointer scale_pointer,
|
| 147 |
+
/// Pointer to start of scale vector
|
| 148 |
+
ConstPointer bias_pointer,
|
| 149 |
+
///< ID of each participating thread
|
| 150 |
+
int thread_id)
|
| 151 |
+
: PredicatedScaleBiasVectorIterator(problem_size,
|
| 152 |
+
scale_pointer, bias_pointer,
|
| 153 |
+
thread_id, make_Coord(0, 0)) {}
|
| 154 |
+
|
| 155 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole warp tiles
|
| 156 |
+
CUTLASS_DEVICE
|
| 157 |
+
void add_tile_offset(
|
| 158 |
+
TensorCoord const &tile_offset) {
|
| 159 |
+
|
| 160 |
+
thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous());
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
/// Loads a fragment from memory
|
| 164 |
+
CUTLASS_DEVICE
|
| 165 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 166 |
+
|
| 167 |
+
frag.fill(__float2half2_rn(0.0f));
|
| 168 |
+
__half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag);
|
| 169 |
+
|
| 170 |
+
// load scale
|
| 171 |
+
CUTLASS_PRAGMA_UNROLL
|
| 172 |
+
for (int c = 0; c < kIterations; ++c) {
|
| 173 |
+
|
| 174 |
+
cutlass::arch::global_load<
|
| 175 |
+
__half,
|
| 176 |
+
sizeof(AccessType)
|
| 177 |
+
>(
|
| 178 |
+
frag_ptr[c * 2].x,
|
| 179 |
+
scale_pointer_ + thread_offset_ + c * 8,
|
| 180 |
+
(thread_offset_ + c * 8) < problem_size_
|
| 181 |
+
);
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// load bias
|
| 185 |
+
CUTLASS_PRAGMA_UNROLL
|
| 186 |
+
for (int c = 0; c < kIterations; ++c) {
|
| 187 |
+
|
| 188 |
+
cutlass::arch::global_load<
|
| 189 |
+
__half,
|
| 190 |
+
sizeof(AccessType)
|
| 191 |
+
>(
|
| 192 |
+
frag_ptr[c * 2 + 1].x,
|
| 193 |
+
bias_pointer_ + thread_offset_ + c * 8,
|
| 194 |
+
(thread_offset_ + c * 8) < problem_size_
|
| 195 |
+
);
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
// duplicate scale
|
| 199 |
+
CUTLASS_PRAGMA_UNROLL
|
| 200 |
+
for (int c = 0; c < kIterations; ++c) {
|
| 201 |
+
frag_ptr[c * 2].y = frag_ptr[c * 2].x;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
// duplicate bias
|
| 205 |
+
CUTLASS_PRAGMA_UNROLL
|
| 206 |
+
for (int c = 0; c < kIterations; ++c) {
|
| 207 |
+
frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x;
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
/// Loads a fragment from memory
|
| 212 |
+
CUTLASS_DEVICE
|
| 213 |
+
void load(Fragment &frag) {
|
| 214 |
+
load_with_pointer_offset(frag, 0);
|
| 215 |
+
}
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 219 |
+
|
| 220 |
+
/// Specialization of PredicatedTileIterator for row-major data.
|
| 221 |
+
///
|
| 222 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 223 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 224 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 225 |
+
/// MaskedTileIteratorConcept
|
| 226 |
+
///
|
| 227 |
+
template <typename WarpShape_,
|
| 228 |
+
typename Element_>
|
| 229 |
+
class PredicatedScaleBiasVectorIterator<WarpShape_,
|
| 230 |
+
Element_,
|
| 231 |
+
layout::RowMajor> {
|
| 232 |
+
public:
|
| 233 |
+
|
| 234 |
+
using WarpShape = WarpShape_;
|
| 235 |
+
using Element = Element_;
|
| 236 |
+
using Layout = layout::RowMajor;
|
| 237 |
+
|
| 238 |
+
using Index = typename Layout::Index;
|
| 239 |
+
using LongIndex = typename Layout::LongIndex;
|
| 240 |
+
|
| 241 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 242 |
+
using TensorView = TensorView<Element, Layout>;
|
| 243 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 244 |
+
|
| 245 |
+
using ConstPointer = const Element *;
|
| 246 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 247 |
+
|
| 248 |
+
using UnderlyingIterator = PredicatedScaleBiasVectorIterator<
|
| 249 |
+
layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
|
| 250 |
+
Element,
|
| 251 |
+
layout::PitchLinear>;
|
| 252 |
+
|
| 253 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 254 |
+
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
|
| 255 |
+
using Fragment = typename UnderlyingIterator::Fragment;
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
private:
|
| 259 |
+
//
|
| 260 |
+
// Data members
|
| 261 |
+
//
|
| 262 |
+
|
| 263 |
+
/// Underlying pitch-linear tile iterator
|
| 264 |
+
UnderlyingIterator iterator_;
|
| 265 |
+
|
| 266 |
+
public:
|
| 267 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 268 |
+
/// and thread ID
|
| 269 |
+
CUTLASS_HOST_DEVICE
|
| 270 |
+
PredicatedScaleBiasVectorIterator(
|
| 271 |
+
///< Extent of tensor
|
| 272 |
+
int problem_size,
|
| 273 |
+
///< Pointer to the start of the scale vector
|
| 274 |
+
ConstPointer scale_pointer,
|
| 275 |
+
///< Pointer to the start of the bias vector
|
| 276 |
+
ConstPointer bias_pointer,
|
| 277 |
+
///< ID of each participating thread
|
| 278 |
+
int thread_id,
|
| 279 |
+
///< Initial offset of threadblock
|
| 280 |
+
TensorCoord const &threadblock_offset)
|
| 281 |
+
: iterator_(problem_size, scale_pointer, bias_pointer,
|
| 282 |
+
thread_id,
|
| 283 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 284 |
+
threadblock_offset.row())) {}
|
| 285 |
+
|
| 286 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 287 |
+
CUTLASS_HOST_DEVICE
|
| 288 |
+
PredicatedScaleBiasVectorIterator(
|
| 289 |
+
int problem_size, ///< Extent of tensor
|
| 290 |
+
ConstPointer scale_pointer, ///< Pointer to the start of the scale vector
|
| 291 |
+
ConstPointer bias_pointer, ///< Pointer to the start of the bias vector
|
| 292 |
+
int thread_id ///< ID of each participating thread
|
| 293 |
+
)
|
| 294 |
+
: PredicatedScaleBiasVectorIterator(problem_size,
|
| 295 |
+
scale_pointer, bias_pointer,
|
| 296 |
+
thread_id, make_Coord(0, 0)) {}
|
| 297 |
+
|
| 298 |
+
/// Overrides the internal iteration index
|
| 299 |
+
CUTLASS_HOST_DEVICE
|
| 300 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 301 |
+
|
| 302 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 303 |
+
/// threadblock tiles
|
| 304 |
+
CUTLASS_HOST_DEVICE
|
| 305 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 306 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
/// Loads a fragment from memory
|
| 310 |
+
CUTLASS_DEVICE
|
| 311 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 312 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
/// Loads a fragment from memory
|
| 316 |
+
CUTLASS_DEVICE
|
| 317 |
+
void load(Fragment &frag) {
|
| 318 |
+
iterator_.load(frag);
|
| 319 |
+
}
|
| 320 |
+
};
|
| 321 |
+
|
| 322 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 323 |
+
|
| 324 |
+
} // namespace threadblock
|
| 325 |
+
} // namespace transform
|
| 326 |
+
} // namespace cutlass
|
| 327 |
+
|
| 328 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h
ADDED
|
@@ -0,0 +1,2118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates calculating the address and predicates to the load of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
|
| 35 |
+
This iterator uses masks to guard out-of-bounds accesses. The first tile this
|
| 36 |
+
iterator visits maybe partial, then the remaining tiles are complete. So, we
|
| 37 |
+
only need to compute the predicates twice, once before the first tile and
|
| 38 |
+
once for the remaining full tiles which can share the same predicates.
|
| 39 |
+
|
| 40 |
+
A precomputed "Params" object minimizes the amount of state that must be
|
| 41 |
+
stored in registers, and integer addition is used to advance the pointer
|
| 42 |
+
through memory.
|
| 43 |
+
*/
|
| 44 |
+
|
| 45 |
+
#pragma once
|
| 46 |
+
|
| 47 |
+
#include "cutlass/array.h"
|
| 48 |
+
#include "cutlass/coord.h"
|
| 49 |
+
#include "cutlass/cutlass.h"
|
| 50 |
+
#include "cutlass/layout/matrix.h"
|
| 51 |
+
#include "cutlass/layout/permute.h"
|
| 52 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 53 |
+
#include "cutlass/matrix_shape.h"
|
| 54 |
+
#include "cutlass/predicate_vector.h"
|
| 55 |
+
#include "cutlass/tensor_ref.h"
|
| 56 |
+
#include "cutlass/tensor_view.h"
|
| 57 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
namespace cutlass {
|
| 64 |
+
namespace transform {
|
| 65 |
+
namespace threadblock {
|
| 66 |
+
|
| 67 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
/// PredicatedTileAccessIteratorPredicates
|
| 70 |
+
///
|
| 71 |
+
template <typename Shape_, typename Element_, typename Layout_, int AdvanceRank,
|
| 72 |
+
typename ThreadMap_, typename AccessType_>
|
| 73 |
+
class PredicatedTileAccessIteratorPredicates {
|
| 74 |
+
public:
|
| 75 |
+
using Shape = Shape_;
|
| 76 |
+
using Element = Element_;
|
| 77 |
+
using Layout = Layout_;
|
| 78 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 79 |
+
using ThreadMap = ThreadMap_;
|
| 80 |
+
using AccessType = AccessType_;
|
| 81 |
+
|
| 82 |
+
using Index = typename Layout::Index;
|
| 83 |
+
using LongIndex = typename Layout::LongIndex;
|
| 84 |
+
|
| 85 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 86 |
+
|
| 87 |
+
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
| 88 |
+
|
| 89 |
+
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
| 90 |
+
"Vectors implied by the thread map must be divisible by the access type.");
|
| 91 |
+
|
| 92 |
+
static int const kPredicatesPerByte = 4;
|
| 93 |
+
static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
|
| 94 |
+
|
| 95 |
+
static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
|
| 96 |
+
|
| 97 |
+
/// Number of 32b words containing predicates
|
| 98 |
+
static int const kPredicateByteCount =
|
| 99 |
+
(kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
| 100 |
+
static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
|
| 101 |
+
|
| 102 |
+
static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
|
| 103 |
+
|
| 104 |
+
static_assert(kPredicateWordCount <= 4, "Too many predicates.");
|
| 105 |
+
|
| 106 |
+
/// Predicate vector stores mask to guard accesses
|
| 107 |
+
using Mask = Array<uint32_t, kPredicateWordCount>;
|
| 108 |
+
|
| 109 |
+
// private:
|
| 110 |
+
/// Guard predicates
|
| 111 |
+
uint32_t predicates_[kPredicateWordCount];
|
| 112 |
+
|
| 113 |
+
/// Size of tensor
|
| 114 |
+
TensorCoord extent_;
|
| 115 |
+
|
| 116 |
+
/// Initial offset for each thread
|
| 117 |
+
TensorCoord thread_offset_;
|
| 118 |
+
|
| 119 |
+
/// Offset to the first steady-state tile
|
| 120 |
+
TensorCoord residue_offset_;
|
| 121 |
+
|
| 122 |
+
/// Iteration along vectors implied by the thread map
|
| 123 |
+
int iteration_vector_;
|
| 124 |
+
|
| 125 |
+
/// Iteration in the contiguous dimension
|
| 126 |
+
int iteration_contiguous_;
|
| 127 |
+
|
| 128 |
+
/// Iteration in the strided dimension
|
| 129 |
+
int iteration_strided_;
|
| 130 |
+
|
| 131 |
+
public:
|
| 132 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 133 |
+
CUTLASS_DEVICE
|
| 134 |
+
void compute_predicates_(
|
| 135 |
+
/// Extent of the matrix window
|
| 136 |
+
TensorCoord extent,
|
| 137 |
+
/// optionally, simplify predicate calculation during 'steady state' phase
|
| 138 |
+
bool is_steady_state = false) {
|
| 139 |
+
|
| 140 |
+
CUTLASS_PRAGMA_UNROLL
|
| 141 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 142 |
+
predicates_[i] = 0u;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
CUTLASS_PRAGMA_UNROLL
|
| 146 |
+
for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
|
| 147 |
+
|
| 148 |
+
int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 149 |
+
|
| 150 |
+
int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 151 |
+
|
| 152 |
+
int c = access_residual / kAccessesPerVector;
|
| 153 |
+
int v = access_residual % kAccessesPerVector;
|
| 154 |
+
|
| 155 |
+
TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
|
| 156 |
+
s * ThreadMap::Delta::kStrided);
|
| 157 |
+
|
| 158 |
+
TensorCoord coord = thread_offset_ + iteration_coord;
|
| 159 |
+
|
| 160 |
+
bool guard;
|
| 161 |
+
|
| 162 |
+
if (is_steady_state) {
|
| 163 |
+
if (kAdvanceRank == 0) {
|
| 164 |
+
guard = (coord.strided() < extent.strided());
|
| 165 |
+
} else {
|
| 166 |
+
guard = (coord.contiguous() < extent.contiguous());
|
| 167 |
+
}
|
| 168 |
+
} else {
|
| 169 |
+
guard = (coord.strided() < extent.strided() &&
|
| 170 |
+
coord.contiguous() < extent.contiguous());
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
|
| 174 |
+
|
| 175 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 176 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 177 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 178 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 179 |
+
|
| 180 |
+
predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
|
| 181 |
+
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
CUTLASS_HOST_DEVICE
|
| 187 |
+
void set_predicates(int thread_id, TensorCoord const &threadblock_offset) {
|
| 188 |
+
|
| 189 |
+
TensorCoord residue_extent;
|
| 190 |
+
if (kAdvanceRank) {
|
| 191 |
+
|
| 192 |
+
typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided;
|
| 193 |
+
if (!residue_size) {
|
| 194 |
+
residue_size = Shape::kStrided;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
residue_offset_ = make_Coord(0, residue_size);
|
| 198 |
+
residue_extent = make_Coord(
|
| 199 |
+
extent_.contiguous(),
|
| 200 |
+
min(threadblock_offset.strided() + residue_size, extent_.strided())
|
| 201 |
+
);
|
| 202 |
+
} else {
|
| 203 |
+
|
| 204 |
+
typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous;
|
| 205 |
+
if (!residue_size) {
|
| 206 |
+
residue_size = Shape::kContiguous;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
residue_offset_ = make_Coord(residue_size, 0);
|
| 210 |
+
|
| 211 |
+
residue_extent = make_Coord(
|
| 212 |
+
min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size),
|
| 213 |
+
extent_.strided()
|
| 214 |
+
);
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
// Per-thread offset in logical coordinates of tensor
|
| 218 |
+
thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
|
| 219 |
+
|
| 220 |
+
compute_predicates_(residue_extent, false);
|
| 221 |
+
|
| 222 |
+
set_iteration_index(0);
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
/// Default constructor
|
| 226 |
+
PredicatedTileAccessIteratorPredicates() = default;
|
| 227 |
+
|
| 228 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 229 |
+
/// and thread ID
|
| 230 |
+
CUTLASS_HOST_DEVICE
|
| 231 |
+
PredicatedTileAccessIteratorPredicates(
|
| 232 |
+
/// Extent of tensor
|
| 233 |
+
TensorCoord extent)
|
| 234 |
+
: extent_(extent) {
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/// Overrides the internal iteration index
|
| 238 |
+
CUTLASS_HOST_DEVICE
|
| 239 |
+
void set_iteration_index(int index) {
|
| 240 |
+
|
| 241 |
+
iteration_vector_ = index % kAccessesPerVector;
|
| 242 |
+
int residual_access = index / kAccessesPerVector;
|
| 243 |
+
|
| 244 |
+
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
| 245 |
+
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
| 246 |
+
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
/// Increment and return an instance to self.
|
| 250 |
+
CUTLASS_HOST_DEVICE
|
| 251 |
+
PredicatedTileAccessIteratorPredicates &operator++() {
|
| 252 |
+
|
| 253 |
+
return *this;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
/// Clears the predicate set efficiently
|
| 257 |
+
CUTLASS_HOST_DEVICE
|
| 258 |
+
void clear_mask(bool enable = true) {
|
| 259 |
+
CUTLASS_PRAGMA_UNROLL
|
| 260 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 261 |
+
predicates_[i] = enable ? 0u : predicates_[i];
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
/// Clears the predicate set efficiently
|
| 267 |
+
CUTLASS_HOST_DEVICE
|
| 268 |
+
void enable_mask() {
|
| 269 |
+
CUTLASS_PRAGMA_UNROLL
|
| 270 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 271 |
+
predicates_[i] = 0xffffffff;
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 276 |
+
CUTLASS_HOST_DEVICE
|
| 277 |
+
void set_mask(Mask const &mask) {
|
| 278 |
+
CUTLASS_PRAGMA_UNROLL
|
| 279 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 280 |
+
predicates_[i] = mask[i];
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
/// Gets the mask
|
| 286 |
+
CUTLASS_HOST_DEVICE
|
| 287 |
+
void get_mask(Mask &mask) {
|
| 288 |
+
CUTLASS_PRAGMA_UNROLL
|
| 289 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 290 |
+
mask[i] = predicates_[i];
|
| 291 |
+
}
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
/// Returns whether access is valid or not
|
| 295 |
+
CUTLASS_HOST_DEVICE
|
| 296 |
+
bool valid() const {
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
int pred_idx =
|
| 300 |
+
iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
|
| 301 |
+
|
| 302 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 303 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 304 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 305 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 306 |
+
|
| 307 |
+
bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
|
| 308 |
+
return pred;
|
| 309 |
+
|
| 310 |
+
}
|
| 311 |
+
};
|
| 312 |
+
|
| 313 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 314 |
+
|
| 315 |
+
/// PredicatedTileAccessIterator
|
| 316 |
+
///
|
| 317 |
+
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
|
| 318 |
+
typename ThreadMap, typename AccessType, bool Gather = false,
|
| 319 |
+
typename PermuteLayout = layout::NoPermute>
|
| 320 |
+
class PredicatedTileAccessIterator;
|
| 321 |
+
|
| 322 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 323 |
+
|
| 324 |
+
/// Specialization of PredicatedTileAccessIterator for pitch-linear data.
|
| 325 |
+
///
|
| 326 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 327 |
+
typename ThreadMap_, typename AccessType_, bool Gather,
|
| 328 |
+
typename PermuteLayout>
|
| 329 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
| 330 |
+
AdvanceRank, ThreadMap_, AccessType_, Gather,
|
| 331 |
+
PermuteLayout> {
|
| 332 |
+
public:
|
| 333 |
+
static_assert(
|
| 334 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 335 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 336 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 337 |
+
|
| 338 |
+
using Shape = Shape_;
|
| 339 |
+
using Element = Element_;
|
| 340 |
+
using Layout = layout::PitchLinear;
|
| 341 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 342 |
+
using ThreadMap = ThreadMap_;
|
| 343 |
+
using AccessType = AccessType_;
|
| 344 |
+
|
| 345 |
+
using Index = typename Layout::Index;
|
| 346 |
+
using LongIndex = typename Layout::LongIndex;
|
| 347 |
+
|
| 348 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 349 |
+
using TensorView = TensorView<Element, Layout>;
|
| 350 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 351 |
+
|
| 352 |
+
using Pointer = Element *;
|
| 353 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 354 |
+
|
| 355 |
+
using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
|
| 356 |
+
Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>;
|
| 357 |
+
|
| 358 |
+
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
| 359 |
+
|
| 360 |
+
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
| 361 |
+
"Vectors implied by the thread map must be divisible by the access type.");
|
| 362 |
+
|
| 363 |
+
static bool constexpr Permute = !platform::is_same<PermuteLayout, layout::NoPermute>::value
|
| 364 |
+
&& !platform::is_same<PermuteLayout, layout::InversePermute<layout::NoPermute>>::value;
|
| 365 |
+
|
| 366 |
+
using Mask = typename UnderlyingPredicates::Mask;
|
| 367 |
+
|
| 368 |
+
/// Uses a non-template class
|
| 369 |
+
struct Params : PredicatedTileAccessIteratorParams {
|
| 370 |
+
|
| 371 |
+
using Base = PredicatedTileAccessIteratorParams;
|
| 372 |
+
|
| 373 |
+
/// Default constructor
|
| 374 |
+
Params() = default;
|
| 375 |
+
|
| 376 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 377 |
+
CUTLASS_HOST_DEVICE
|
| 378 |
+
Params(Layout const &layout) :
|
| 379 |
+
Base(layout.stride(0),
|
| 380 |
+
MakePredicatedTileAccessIteratorDesc<Shape, Element, Layout, kAdvanceRank, ThreadMap>()()
|
| 381 |
+
) { }
|
| 382 |
+
|
| 383 |
+
CUTLASS_HOST_DEVICE
|
| 384 |
+
Params(Base const &base) :
|
| 385 |
+
Base(base) { }
|
| 386 |
+
};
|
| 387 |
+
|
| 388 |
+
private:
|
| 389 |
+
/// Internal pointer type permits fast address arithmetic
|
| 390 |
+
using BytePointer = char *;
|
| 391 |
+
|
| 392 |
+
private:
|
| 393 |
+
//
|
| 394 |
+
// Data members
|
| 395 |
+
//
|
| 396 |
+
|
| 397 |
+
UnderlyingPredicates the_predicates;
|
| 398 |
+
|
| 399 |
+
/// Parameters object with precomputed internal state
|
| 400 |
+
Params params_;
|
| 401 |
+
|
| 402 |
+
/// Internal pointer to first access of tile
|
| 403 |
+
BytePointer pointer_;
|
| 404 |
+
|
| 405 |
+
/// Used for out-of-order visitation
|
| 406 |
+
bool is_residue_tile_;
|
| 407 |
+
|
| 408 |
+
/// Below is used when Gather is turned on. We need to record strided_offset
|
| 409 |
+
/// and contiguous_offset separated to compute the offset by using
|
| 410 |
+
///
|
| 411 |
+
/// offset = contiguous_offset + indices[strided_offset]
|
| 412 |
+
|
| 413 |
+
/// Gather indices
|
| 414 |
+
int const *indices_;
|
| 415 |
+
|
| 416 |
+
/// Function to perform layout permutation and offset computation
|
| 417 |
+
PermuteLayout permute_layout_;
|
| 418 |
+
|
| 419 |
+
/// Tracks thread's coordinate offset in the matrix for current tile.
|
| 420 |
+
/// This is only used in the following cases:
|
| 421 |
+
/// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_)
|
| 422 |
+
/// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed)
|
| 423 |
+
TensorCoord coord_offset_;
|
| 424 |
+
|
| 425 |
+
private:
|
| 426 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 427 |
+
CUTLASS_DEVICE
|
| 428 |
+
void compute_predicates_(
|
| 429 |
+
/// Extent of the matrix window
|
| 430 |
+
TensorCoord extent,
|
| 431 |
+
/// optionally, simplify predicate calculation during 'steady state' phase
|
| 432 |
+
bool is_steady_state = false) {
|
| 433 |
+
the_predicates.compute_predicates_(extent, is_steady_state);
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
public:
|
| 437 |
+
|
| 438 |
+
/// Default constructor
|
| 439 |
+
PredicatedTileAccessIterator() = default;
|
| 440 |
+
|
| 441 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 442 |
+
/// and thread ID
|
| 443 |
+
CUTLASS_HOST_DEVICE
|
| 444 |
+
PredicatedTileAccessIterator(
|
| 445 |
+
/// Precomputed parameters object
|
| 446 |
+
Params const ¶ms,
|
| 447 |
+
/// Pointer to start of tensor
|
| 448 |
+
Pointer pointer,
|
| 449 |
+
/// Extent of tensor
|
| 450 |
+
TensorCoord extent,
|
| 451 |
+
/// ID of each participating thread
|
| 452 |
+
int thread_id,
|
| 453 |
+
/// Initial offset of threadblock
|
| 454 |
+
TensorCoord const &threadblock_offset,
|
| 455 |
+
/// Gather indices
|
| 456 |
+
int const *indices = nullptr)
|
| 457 |
+
: params_(params),
|
| 458 |
+
pointer_(reinterpret_cast<BytePointer>(
|
| 459 |
+
const_cast<NonConstPointer>(pointer))),
|
| 460 |
+
the_predicates(extent),
|
| 461 |
+
is_residue_tile_(true),
|
| 462 |
+
indices_(indices),
|
| 463 |
+
permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) {
|
| 464 |
+
|
| 465 |
+
the_predicates.set_predicates(thread_id, threadblock_offset);
|
| 466 |
+
|
| 467 |
+
if (Gather) {
|
| 468 |
+
assert(indices_);
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
// update internal pointers
|
| 472 |
+
Layout layout(params_.stride_);
|
| 473 |
+
|
| 474 |
+
if (!Gather && !Permute) {
|
| 475 |
+
add_pointer_offset(layout(the_predicates.thread_offset_));
|
| 476 |
+
} else {
|
| 477 |
+
coord_offset_ = the_predicates.thread_offset_;
|
| 478 |
+
if (!Permute) {
|
| 479 |
+
add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0)));
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 485 |
+
CUTLASS_HOST_DEVICE
|
| 486 |
+
PredicatedTileAccessIterator(
|
| 487 |
+
/// Precomputed parameters object
|
| 488 |
+
Params const ¶ms,
|
| 489 |
+
/// Pointer to start of tensor
|
| 490 |
+
Pointer pointer,
|
| 491 |
+
/// Extent of tensor
|
| 492 |
+
TensorCoord extent,
|
| 493 |
+
///< ID of each participating thread
|
| 494 |
+
int thread_id)
|
| 495 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 496 |
+
make_Coord(0, 0)) {}
|
| 497 |
+
|
| 498 |
+
/// Overrides the internal iteration index
|
| 499 |
+
CUTLASS_HOST_DEVICE
|
| 500 |
+
void set_iteration_index(int index) {
|
| 501 |
+
the_predicates.set_iteration_index(index);
|
| 502 |
+
}
|
| 503 |
+
|
| 504 |
+
/// Adds a pointer offset in units of Element
|
| 505 |
+
CUTLASS_HOST_DEVICE
|
| 506 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 507 |
+
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
| 511 |
+
CUTLASS_DEVICE
|
| 512 |
+
void add_tile_offset(
|
| 513 |
+
TensorCoord const &tile_offset) {
|
| 514 |
+
if (is_residue_tile_) {
|
| 515 |
+
|
| 516 |
+
the_predicates.thread_offset_ += the_predicates.residue_offset_;
|
| 517 |
+
|
| 518 |
+
the_predicates.compute_predicates_(the_predicates.extent_, true);
|
| 519 |
+
|
| 520 |
+
Layout layout(params_.stride_);
|
| 521 |
+
|
| 522 |
+
if (!Gather && !Permute) {
|
| 523 |
+
add_pointer_offset(layout(the_predicates.residue_offset_));
|
| 524 |
+
|
| 525 |
+
if (kAdvanceRank) {
|
| 526 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1);
|
| 527 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits<Element>::value / 8;
|
| 528 |
+
} else {
|
| 529 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1);
|
| 530 |
+
pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits<Element>::value / 8;
|
| 531 |
+
}
|
| 532 |
+
} else {
|
| 533 |
+
coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank);
|
| 534 |
+
if (!Permute) {
|
| 535 |
+
add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0)));
|
| 536 |
+
add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank)));
|
| 537 |
+
} else {
|
| 538 |
+
coord_offset_.contiguous() = the_predicates.thread_offset_.contiguous() + Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank));
|
| 539 |
+
}
|
| 540 |
+
}
|
| 541 |
+
} else {
|
| 542 |
+
if (!Gather && !Permute) {
|
| 543 |
+
if (kAdvanceRank) {
|
| 544 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
|
| 545 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 546 |
+
} else {
|
| 547 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
|
| 548 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 549 |
+
}
|
| 550 |
+
} else {
|
| 551 |
+
coord_offset_.strided() += Shape::kStrided * tile_offset.strided();
|
| 552 |
+
if (!Permute) {
|
| 553 |
+
add_pointer_offset(Shape::kContiguous * tile_offset.contiguous());
|
| 554 |
+
} else {
|
| 555 |
+
coord_offset_.contiguous() += Shape::kContiguous * tile_offset.contiguous();
|
| 556 |
+
}
|
| 557 |
+
}
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
is_residue_tile_ = false;
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
/// Returns a pointer
|
| 564 |
+
CUTLASS_HOST_DEVICE
|
| 565 |
+
AccessType *get() const {
|
| 566 |
+
|
| 567 |
+
if (Gather || Permute)
|
| 568 |
+
{
|
| 569 |
+
if (!valid()) {
|
| 570 |
+
return nullptr;
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements;
|
| 574 |
+
Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided;
|
| 575 |
+
if (Gather) {
|
| 576 |
+
coord_strided = indices_[coord_strided];
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig);
|
| 580 |
+
return reinterpret_cast<AccessType *>(pointer_ + OffsetBytes<Element>(offset));
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
return reinterpret_cast<AccessType *>(
|
| 584 |
+
pointer_ +
|
| 585 |
+
the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + the_predicates.iteration_vector_;
|
| 586 |
+
}
|
| 587 |
+
|
| 588 |
+
/// Increment and return an instance to self.
|
| 589 |
+
CUTLASS_HOST_DEVICE
|
| 590 |
+
PredicatedTileAccessIterator &operator++() {
|
| 591 |
+
|
| 592 |
+
the_predicates.operator++();
|
| 593 |
+
|
| 594 |
+
++the_predicates.iteration_vector_;
|
| 595 |
+
if (the_predicates.iteration_vector_ < kAccessesPerVector) {
|
| 596 |
+
return *this;
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
the_predicates.iteration_vector_ = 0;
|
| 600 |
+
++the_predicates.iteration_contiguous_;
|
| 601 |
+
|
| 602 |
+
if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
| 603 |
+
return *this;
|
| 604 |
+
}
|
| 605 |
+
|
| 606 |
+
// Enter here only if (iteration_contiguous_ == ThreadMap::Iteration::kContiguous)
|
| 607 |
+
the_predicates.iteration_contiguous_ = 0;
|
| 608 |
+
++the_predicates.iteration_strided_;
|
| 609 |
+
|
| 610 |
+
if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 611 |
+
if (!Gather && !Permute) {
|
| 612 |
+
pointer_ += params_.inc_strided_;
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
return *this;
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 619 |
+
// which means we enter the next tile.
|
| 620 |
+
the_predicates.iteration_strided_ = 0;
|
| 621 |
+
|
| 622 |
+
if (!Gather && !Permute) {
|
| 623 |
+
// advance to next tile
|
| 624 |
+
pointer_ += params_.inc_next_;
|
| 625 |
+
|
| 626 |
+
// now return to start tile - if the iterator is subsequently advanced, this
|
| 627 |
+
// subtraction as well as the subsequent integer addition are both elided by
|
| 628 |
+
// the compiler.
|
| 629 |
+
pointer_ -= params_.inc_advance_;
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
return *this;
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
/// Increment and return an instance to self.
|
| 636 |
+
CUTLASS_HOST_DEVICE
|
| 637 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 638 |
+
PredicatedTileAccessIterator self(*this);
|
| 639 |
+
operator++();
|
| 640 |
+
return self;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
/// Clears the predicate set efficiently
|
| 644 |
+
CUTLASS_HOST_DEVICE
|
| 645 |
+
void clear_mask(bool enable = true) {
|
| 646 |
+
the_predicates.clear_mask(enable);
|
| 647 |
+
}
|
| 648 |
+
|
| 649 |
+
/// Clears the predicate set efficiently
|
| 650 |
+
CUTLASS_HOST_DEVICE
|
| 651 |
+
void enable_mask() {
|
| 652 |
+
the_predicates.enable_mask();
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 656 |
+
CUTLASS_HOST_DEVICE
|
| 657 |
+
void set_mask(Mask const &mask) {
|
| 658 |
+
the_predicates.set_mask(mask);
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
/// Gets the mask
|
| 662 |
+
CUTLASS_HOST_DEVICE
|
| 663 |
+
void get_mask(Mask &mask) {
|
| 664 |
+
the_predicates.get_mask(mask);
|
| 665 |
+
}
|
| 666 |
+
|
| 667 |
+
/// Returns whether access is valid or not
|
| 668 |
+
CUTLASS_HOST_DEVICE
|
| 669 |
+
bool valid() const {
|
| 670 |
+
return the_predicates.valid();
|
| 671 |
+
}
|
| 672 |
+
};
|
| 673 |
+
|
| 674 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 675 |
+
|
| 676 |
+
/// Specialization of PredicatedTileAccessIterator for column-major data.
|
| 677 |
+
///
|
| 678 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 679 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 680 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 681 |
+
/// MaskedTileIteratorConcept
|
| 682 |
+
///
|
| 683 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 684 |
+
typename ThreadMap_, typename AccessType_, bool Gather,
|
| 685 |
+
typename PermuteLayout>
|
| 686 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
|
| 687 |
+
AdvanceRank, ThreadMap_, AccessType_, Gather,
|
| 688 |
+
PermuteLayout> {
|
| 689 |
+
public:
|
| 690 |
+
static_assert(
|
| 691 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 692 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 693 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 694 |
+
|
| 695 |
+
using Shape = Shape_;
|
| 696 |
+
using Element = Element_;
|
| 697 |
+
using Layout = layout::ColumnMajor;
|
| 698 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 699 |
+
using ThreadMap = ThreadMap_;
|
| 700 |
+
using AccessType = AccessType_;
|
| 701 |
+
|
| 702 |
+
using Index = typename Layout::Index;
|
| 703 |
+
using LongIndex = typename Layout::LongIndex;
|
| 704 |
+
|
| 705 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 706 |
+
using TensorView = TensorView<Element, Layout>;
|
| 707 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 708 |
+
|
| 709 |
+
using Pointer = Element *;
|
| 710 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 711 |
+
|
| 712 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 713 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 714 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType,
|
| 715 |
+
Gather, PermuteLayout>;
|
| 716 |
+
|
| 717 |
+
/// Predicate vector stores mask to guard accesses
|
| 718 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 719 |
+
|
| 720 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 721 |
+
|
| 722 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 723 |
+
class Params {
|
| 724 |
+
private:
|
| 725 |
+
friend PredicatedTileAccessIterator;
|
| 726 |
+
|
| 727 |
+
/// Parameters object
|
| 728 |
+
typename UnderlyingIterator::Params params_;
|
| 729 |
+
|
| 730 |
+
public:
|
| 731 |
+
|
| 732 |
+
/// Default constructor
|
| 733 |
+
Params() = default;
|
| 734 |
+
|
| 735 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 736 |
+
CUTLASS_HOST_DEVICE
|
| 737 |
+
Params(Layout const &layout)
|
| 738 |
+
: params_(layout::PitchLinear(layout.stride(0))){};
|
| 739 |
+
|
| 740 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 741 |
+
CUTLASS_HOST_DEVICE
|
| 742 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 743 |
+
: params_(base) {}
|
| 744 |
+
};
|
| 745 |
+
|
| 746 |
+
private:
|
| 747 |
+
//
|
| 748 |
+
// Data members
|
| 749 |
+
//
|
| 750 |
+
|
| 751 |
+
/// Underlying pitch-linear tile iterator
|
| 752 |
+
UnderlyingIterator iterator_;
|
| 753 |
+
|
| 754 |
+
public:
|
| 755 |
+
|
| 756 |
+
/// Default constructor
|
| 757 |
+
PredicatedTileAccessIterator() = default;
|
| 758 |
+
|
| 759 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 760 |
+
/// and thread ID
|
| 761 |
+
CUTLASS_HOST_DEVICE
|
| 762 |
+
PredicatedTileAccessIterator(
|
| 763 |
+
///< Precomputed parameters object
|
| 764 |
+
Params const ¶ms,
|
| 765 |
+
///< Pointer to start of tensor
|
| 766 |
+
Pointer pointer,
|
| 767 |
+
///< Extent of tensor
|
| 768 |
+
TensorCoord extent,
|
| 769 |
+
///< ID of each participating thread
|
| 770 |
+
int thread_id,
|
| 771 |
+
///< Initial offset of threadblock
|
| 772 |
+
TensorCoord const &threadblock_offset,
|
| 773 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 774 |
+
)
|
| 775 |
+
: iterator_(params.params_, pointer,
|
| 776 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 777 |
+
thread_id,
|
| 778 |
+
layout::PitchLinearCoord(threadblock_offset.row(),
|
| 779 |
+
threadblock_offset.column()),
|
| 780 |
+
indices) {}
|
| 781 |
+
|
| 782 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 783 |
+
CUTLASS_HOST_DEVICE
|
| 784 |
+
PredicatedTileAccessIterator(
|
| 785 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 786 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 787 |
+
TensorCoord extent, ///< Extent of tensor
|
| 788 |
+
int thread_id ///< ID of each participating thread
|
| 789 |
+
)
|
| 790 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 791 |
+
make_Coord(0, 0)) {}
|
| 792 |
+
|
| 793 |
+
/// Overrides the internal iteration index
|
| 794 |
+
CUTLASS_HOST_DEVICE
|
| 795 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 796 |
+
|
| 797 |
+
/// Adds a pointer offset in units of Element
|
| 798 |
+
CUTLASS_HOST_DEVICE
|
| 799 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 800 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 804 |
+
/// tiles
|
| 805 |
+
CUTLASS_HOST_DEVICE
|
| 806 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 807 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 808 |
+
}
|
| 809 |
+
|
| 810 |
+
/// Returns a pointer
|
| 811 |
+
CUTLASS_HOST_DEVICE
|
| 812 |
+
AccessType *get() const {
|
| 813 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 814 |
+
}
|
| 815 |
+
|
| 816 |
+
/// Advances to the next tile in memory.
|
| 817 |
+
///
|
| 818 |
+
/// The first time this method is called, predicates are updated, and the
|
| 819 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 820 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 821 |
+
/// pointer.
|
| 822 |
+
CUTLASS_HOST_DEVICE
|
| 823 |
+
PredicatedTileAccessIterator &operator++() {
|
| 824 |
+
++iterator_;
|
| 825 |
+
return *this;
|
| 826 |
+
}
|
| 827 |
+
|
| 828 |
+
/// Advances to the next tile in memory.
|
| 829 |
+
///
|
| 830 |
+
/// The first time this method is called, predicates are updated, and the
|
| 831 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 832 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 833 |
+
/// pointer.
|
| 834 |
+
CUTLASS_HOST_DEVICE
|
| 835 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 836 |
+
PredicatedTileAccessIterator self(*this);
|
| 837 |
+
operator++();
|
| 838 |
+
return self;
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
/// Clears the predicate set efficiently
|
| 842 |
+
CUTLASS_HOST_DEVICE
|
| 843 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 844 |
+
|
| 845 |
+
/// Clears the predicate set efficiently
|
| 846 |
+
CUTLASS_HOST_DEVICE
|
| 847 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 848 |
+
|
| 849 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 850 |
+
CUTLASS_HOST_DEVICE
|
| 851 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 852 |
+
|
| 853 |
+
/// Gets the mask
|
| 854 |
+
CUTLASS_HOST_DEVICE
|
| 855 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 856 |
+
|
| 857 |
+
/// Returns whether access is valid or not
|
| 858 |
+
CUTLASS_HOST_DEVICE
|
| 859 |
+
bool valid() {
|
| 860 |
+
return iterator_.valid();
|
| 861 |
+
}
|
| 862 |
+
};
|
| 863 |
+
|
| 864 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 865 |
+
|
| 866 |
+
/// Specialization of PredicatedTileAccessIterator for row-major data.
|
| 867 |
+
///
|
| 868 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 869 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 870 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 871 |
+
/// MaskedTileIteratorConcept
|
| 872 |
+
///
|
| 873 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 874 |
+
typename ThreadMap_, typename AccessType_, bool Gather,
|
| 875 |
+
typename PermuteLayout>
|
| 876 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
|
| 877 |
+
AdvanceRank, ThreadMap_, AccessType_, Gather,
|
| 878 |
+
PermuteLayout> {
|
| 879 |
+
public:
|
| 880 |
+
static_assert(
|
| 881 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 882 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 883 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 884 |
+
|
| 885 |
+
using Shape = Shape_;
|
| 886 |
+
using Element = Element_;
|
| 887 |
+
using Layout = layout::RowMajor;
|
| 888 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 889 |
+
using ThreadMap = ThreadMap_;
|
| 890 |
+
using AccessType = AccessType_;
|
| 891 |
+
|
| 892 |
+
using Index = typename Layout::Index;
|
| 893 |
+
using LongIndex = typename Layout::LongIndex;
|
| 894 |
+
|
| 895 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 896 |
+
using TensorView = TensorView<Element, Layout>;
|
| 897 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 898 |
+
|
| 899 |
+
using Pointer = Element *;
|
| 900 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 901 |
+
|
| 902 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 903 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 904 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType,
|
| 905 |
+
Gather, PermuteLayout>;
|
| 906 |
+
|
| 907 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 908 |
+
|
| 909 |
+
/// Predicate vector stores mask to guard accesses
|
| 910 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 911 |
+
|
| 912 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 913 |
+
class Params {
|
| 914 |
+
private:
|
| 915 |
+
friend PredicatedTileAccessIterator;
|
| 916 |
+
|
| 917 |
+
/// Parameters object
|
| 918 |
+
typename UnderlyingIterator::Params params_;
|
| 919 |
+
|
| 920 |
+
public:
|
| 921 |
+
|
| 922 |
+
/// Default constructor
|
| 923 |
+
Params() = default;
|
| 924 |
+
|
| 925 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 926 |
+
CUTLASS_HOST_DEVICE
|
| 927 |
+
Params(Layout const &layout)
|
| 928 |
+
: params_(layout::PitchLinear(layout.stride(0))){};
|
| 929 |
+
|
| 930 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 931 |
+
CUTLASS_HOST_DEVICE
|
| 932 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 933 |
+
: params_(base) {}
|
| 934 |
+
};
|
| 935 |
+
|
| 936 |
+
private:
|
| 937 |
+
//
|
| 938 |
+
// Data members
|
| 939 |
+
//
|
| 940 |
+
|
| 941 |
+
/// Underlying pitch-linear tile iterator
|
| 942 |
+
UnderlyingIterator iterator_;
|
| 943 |
+
|
| 944 |
+
public:
|
| 945 |
+
|
| 946 |
+
/// Default constructor
|
| 947 |
+
PredicatedTileAccessIterator() = default;
|
| 948 |
+
|
| 949 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 950 |
+
/// and thread ID
|
| 951 |
+
CUTLASS_HOST_DEVICE
|
| 952 |
+
PredicatedTileAccessIterator(
|
| 953 |
+
///< Precomputed parameters object
|
| 954 |
+
Params const ¶ms,
|
| 955 |
+
///< Pointer to start of tensor
|
| 956 |
+
Pointer pointer,
|
| 957 |
+
///< Extent of tensor
|
| 958 |
+
TensorCoord extent,
|
| 959 |
+
///< ID of each participating thread
|
| 960 |
+
int thread_id,
|
| 961 |
+
///< Initial offset of threadblock
|
| 962 |
+
TensorCoord const &threadblock_offset,
|
| 963 |
+
/// Gather indices
|
| 964 |
+
int const *indices = nullptr)
|
| 965 |
+
: iterator_(params.params_, pointer,
|
| 966 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 967 |
+
thread_id,
|
| 968 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 969 |
+
threadblock_offset.row()),
|
| 970 |
+
indices) {}
|
| 971 |
+
|
| 972 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 973 |
+
CUTLASS_HOST_DEVICE
|
| 974 |
+
PredicatedTileAccessIterator(
|
| 975 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 976 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 977 |
+
TensorCoord extent, ///< Extent of tensor
|
| 978 |
+
int thread_id ///< ID of each participating thread
|
| 979 |
+
)
|
| 980 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 981 |
+
make_Coord(0, 0)) {}
|
| 982 |
+
|
| 983 |
+
/// Overrides the internal iteration index
|
| 984 |
+
CUTLASS_HOST_DEVICE
|
| 985 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 986 |
+
|
| 987 |
+
/// Adds a pointer offset in units of Element
|
| 988 |
+
CUTLASS_HOST_DEVICE
|
| 989 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 990 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 991 |
+
}
|
| 992 |
+
|
| 993 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 994 |
+
/// tiles
|
| 995 |
+
CUTLASS_HOST_DEVICE
|
| 996 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 997 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
/// Returns a pointer
|
| 1001 |
+
CUTLASS_HOST_DEVICE
|
| 1002 |
+
AccessType *get() const {
|
| 1003 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1004 |
+
}
|
| 1005 |
+
|
| 1006 |
+
/// Advances to the next tile in memory.
|
| 1007 |
+
///
|
| 1008 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1009 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1010 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1011 |
+
/// pointer.
|
| 1012 |
+
CUTLASS_HOST_DEVICE
|
| 1013 |
+
PredicatedTileAccessIterator &operator++() {
|
| 1014 |
+
++iterator_;
|
| 1015 |
+
return *this;
|
| 1016 |
+
}
|
| 1017 |
+
|
| 1018 |
+
/// Advances to the next tile in memory.
|
| 1019 |
+
///
|
| 1020 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1021 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1022 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1023 |
+
/// pointer.
|
| 1024 |
+
CUTLASS_HOST_DEVICE
|
| 1025 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 1026 |
+
PredicatedTileAccessIterator self(*this);
|
| 1027 |
+
operator++();
|
| 1028 |
+
return self;
|
| 1029 |
+
}
|
| 1030 |
+
|
| 1031 |
+
/// Clears the predicate set efficiently
|
| 1032 |
+
CUTLASS_HOST_DEVICE
|
| 1033 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1034 |
+
|
| 1035 |
+
/// Clears the predicate set efficiently
|
| 1036 |
+
CUTLASS_HOST_DEVICE
|
| 1037 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1038 |
+
|
| 1039 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1040 |
+
CUTLASS_HOST_DEVICE
|
| 1041 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1042 |
+
|
| 1043 |
+
/// Gets the mask
|
| 1044 |
+
CUTLASS_HOST_DEVICE
|
| 1045 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1046 |
+
|
| 1047 |
+
/// Returns whether access is valid or not
|
| 1048 |
+
CUTLASS_HOST_DEVICE
|
| 1049 |
+
bool valid() {
|
| 1050 |
+
return iterator_.valid();
|
| 1051 |
+
}
|
| 1052 |
+
};
|
| 1053 |
+
|
| 1054 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1055 |
+
|
| 1056 |
+
/// Specialization of PredicatedTileAccessIterator for affine rank 2 data.
|
| 1057 |
+
///
|
| 1058 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1059 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1060 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1061 |
+
/// MaskedTileIteratorConcept
|
| 1062 |
+
///
|
| 1063 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1064 |
+
typename ThreadMap_, typename AccessType_>
|
| 1065 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRankN<2>,
|
| 1066 |
+
AdvanceRank, ThreadMap_, AccessType_, false,
|
| 1067 |
+
layout::NoPermute> {
|
| 1068 |
+
public:
|
| 1069 |
+
static_assert(
|
| 1070 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1071 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1072 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1073 |
+
|
| 1074 |
+
using Shape = Shape_;
|
| 1075 |
+
using Element = Element_;
|
| 1076 |
+
using Layout = layout::AffineRankN<2>;
|
| 1077 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1078 |
+
using ThreadMap = ThreadMap_;
|
| 1079 |
+
using AccessType = AccessType_;
|
| 1080 |
+
|
| 1081 |
+
using Index = typename Layout::Index;
|
| 1082 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1083 |
+
|
| 1084 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1085 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1086 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1087 |
+
|
| 1088 |
+
using Pointer = Element *;
|
| 1089 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1090 |
+
|
| 1091 |
+
using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
|
| 1092 |
+
Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>;
|
| 1093 |
+
|
| 1094 |
+
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
| 1095 |
+
|
| 1096 |
+
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
| 1097 |
+
"Vectors implied by the thread map must be divisible by the access type.");
|
| 1098 |
+
|
| 1099 |
+
/// Predicate vector stores mask to guard accesses
|
| 1100 |
+
using Mask = typename UnderlyingPredicates::Mask;
|
| 1101 |
+
|
| 1102 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1103 |
+
class Params {
|
| 1104 |
+
public:
|
| 1105 |
+
friend PredicatedTileAccessIterator;
|
| 1106 |
+
|
| 1107 |
+
private:
|
| 1108 |
+
/// stride of pitch-linear layout (units of Element)
|
| 1109 |
+
Coord<Layout::kStrideRank, Layout::LongIndex> stride_;
|
| 1110 |
+
/// amount (in byte) to increment pointer to move to next access along
|
| 1111 |
+
/// contiguous dimension
|
| 1112 |
+
LongIndex inc_contiguous_;
|
| 1113 |
+
/// amount (in byte) to increment pointer from first access of current
|
| 1114 |
+
/// contiguous dimension to first access of next one.
|
| 1115 |
+
LongIndex inc_strided_;
|
| 1116 |
+
/// amount (in byte) to increment pointer from last access of current
|
| 1117 |
+
/// contiguous dimension to first access of next one.
|
| 1118 |
+
LongIndex inc_next_strided_;
|
| 1119 |
+
/// amount (in byte) to increment pointer from last access to first access
|
| 1120 |
+
/// of next tile
|
| 1121 |
+
LongIndex inc_next_;
|
| 1122 |
+
/// amount (in byte) to increment pointer from first access of current tile
|
| 1123 |
+
/// to first access of next tile
|
| 1124 |
+
LongIndex inc_advance_;
|
| 1125 |
+
|
| 1126 |
+
public:
|
| 1127 |
+
|
| 1128 |
+
// Default ctor
|
| 1129 |
+
CUTLASS_HOST_DEVICE
|
| 1130 |
+
Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
|
| 1131 |
+
|
| 1132 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1133 |
+
CUTLASS_HOST_DEVICE
|
| 1134 |
+
Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) {
|
| 1135 |
+
inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) *
|
| 1136 |
+
sizeof_bits<Element>::value / 8;
|
| 1137 |
+
|
| 1138 |
+
inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) *
|
| 1139 |
+
sizeof_bits<Element>::value / 8;
|
| 1140 |
+
|
| 1141 |
+
inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_;
|
| 1142 |
+
|
| 1143 |
+
if (kAdvanceRank) {
|
| 1144 |
+
// advance along strided dimension
|
| 1145 |
+
inc_advance_ =
|
| 1146 |
+
Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits<Element>::value / 8;
|
| 1147 |
+
} else {
|
| 1148 |
+
// advance along contiguous dimension
|
| 1149 |
+
inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits<Element>::value / 8;
|
| 1150 |
+
}
|
| 1151 |
+
|
| 1152 |
+
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_;
|
| 1153 |
+
};
|
| 1154 |
+
};
|
| 1155 |
+
|
| 1156 |
+
private:
|
| 1157 |
+
/// Internal pointer type permits fast address arithmetic
|
| 1158 |
+
using BytePointer = char *;
|
| 1159 |
+
|
| 1160 |
+
//
|
| 1161 |
+
// Data members
|
| 1162 |
+
//
|
| 1163 |
+
|
| 1164 |
+
/// Parameters object with precomputed internal state
|
| 1165 |
+
Params params_;
|
| 1166 |
+
|
| 1167 |
+
/// Internal pointer to first access of tile
|
| 1168 |
+
BytePointer pointer_;
|
| 1169 |
+
|
| 1170 |
+
UnderlyingPredicates the_predicates;
|
| 1171 |
+
|
| 1172 |
+
/// Used for out-of-order visitation
|
| 1173 |
+
bool is_residue_tile_;
|
| 1174 |
+
|
| 1175 |
+
private:
|
| 1176 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 1177 |
+
CUTLASS_DEVICE
|
| 1178 |
+
void compute_predicates_(
|
| 1179 |
+
/// Extent of the matrix window
|
| 1180 |
+
TensorCoord extent,
|
| 1181 |
+
/// optionally, simplify predicate calculation during 'steady state' phase
|
| 1182 |
+
bool is_steady_state = false) {
|
| 1183 |
+
the_predicates.compute_predicates_(extent, is_steady_state);
|
| 1184 |
+
}
|
| 1185 |
+
|
| 1186 |
+
public:
|
| 1187 |
+
|
| 1188 |
+
/// Default constructor
|
| 1189 |
+
PredicatedTileAccessIterator() = default;
|
| 1190 |
+
|
| 1191 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1192 |
+
/// and thread ID
|
| 1193 |
+
CUTLASS_HOST_DEVICE
|
| 1194 |
+
PredicatedTileAccessIterator(
|
| 1195 |
+
///< Precomputed parameters object
|
| 1196 |
+
Params const ¶ms,
|
| 1197 |
+
///< Pointer to start of tensor
|
| 1198 |
+
Pointer pointer,
|
| 1199 |
+
///< Extent of tensor
|
| 1200 |
+
TensorCoord extent,
|
| 1201 |
+
///< ID of each participating thread
|
| 1202 |
+
int thread_id,
|
| 1203 |
+
///< Initial offset of threadblock
|
| 1204 |
+
TensorCoord const &threadblock_offset,
|
| 1205 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1206 |
+
)
|
| 1207 |
+
: params_(params),
|
| 1208 |
+
pointer_(reinterpret_cast<BytePointer>(
|
| 1209 |
+
const_cast<NonConstPointer>(pointer))),
|
| 1210 |
+
the_predicates(extent),
|
| 1211 |
+
is_residue_tile_(true) {
|
| 1212 |
+
|
| 1213 |
+
the_predicates.set_predicates(thread_id, threadblock_offset);
|
| 1214 |
+
|
| 1215 |
+
// update internal pointers
|
| 1216 |
+
Layout layout(params_.stride_);
|
| 1217 |
+
add_pointer_offset(layout(the_predicates.thread_offset_));
|
| 1218 |
+
}
|
| 1219 |
+
|
| 1220 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 1221 |
+
CUTLASS_HOST_DEVICE
|
| 1222 |
+
PredicatedTileAccessIterator(
|
| 1223 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1224 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1225 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1226 |
+
int thread_id ///< ID of each participating thread
|
| 1227 |
+
)
|
| 1228 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1229 |
+
make_Coord(0, 0)) {}
|
| 1230 |
+
|
| 1231 |
+
/// Overrides the internal iteration index
|
| 1232 |
+
CUTLASS_HOST_DEVICE
|
| 1233 |
+
void set_iteration_index(int index) { the_predicates.set_iteration_index(index); }
|
| 1234 |
+
|
| 1235 |
+
/// Adds a pointer offset in units of Element
|
| 1236 |
+
CUTLASS_HOST_DEVICE
|
| 1237 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1238 |
+
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1242 |
+
/// tiles
|
| 1243 |
+
CUTLASS_HOST_DEVICE
|
| 1244 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1245 |
+
if (is_residue_tile_) {
|
| 1246 |
+
|
| 1247 |
+
the_predicates.thread_offset_ += the_predicates.residue_offset_;
|
| 1248 |
+
|
| 1249 |
+
Layout layout(params_.stride_);
|
| 1250 |
+
add_pointer_offset(layout(the_predicates.residue_offset_));
|
| 1251 |
+
|
| 1252 |
+
the_predicates.compute_predicates_(the_predicates.extent_, true);
|
| 1253 |
+
|
| 1254 |
+
if (kAdvanceRank) {
|
| 1255 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1);
|
| 1256 |
+
pointer_ += Shape::kContiguous * tile_offset[0];
|
| 1257 |
+
} else {
|
| 1258 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1);
|
| 1259 |
+
pointer_ += Shape::kStrided * tile_offset[1];
|
| 1260 |
+
}
|
| 1261 |
+
} else {
|
| 1262 |
+
if (kAdvanceRank) {
|
| 1263 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]);
|
| 1264 |
+
pointer_ += Shape::kContiguous * tile_offset[0];
|
| 1265 |
+
} else {
|
| 1266 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]);
|
| 1267 |
+
pointer_ += Shape::kStrided * tile_offset[1];
|
| 1268 |
+
}
|
| 1269 |
+
}
|
| 1270 |
+
is_residue_tile_ = false;
|
| 1271 |
+
}
|
| 1272 |
+
|
| 1273 |
+
/// Returns a pointer
|
| 1274 |
+
CUTLASS_HOST_DEVICE
|
| 1275 |
+
AccessType *get() const {
|
| 1276 |
+
return reinterpret_cast<AccessType *>(pointer_) + the_predicates.iteration_vector_;
|
| 1277 |
+
}
|
| 1278 |
+
|
| 1279 |
+
/// Advances to the next tile in memory.
|
| 1280 |
+
///
|
| 1281 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1282 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1283 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1284 |
+
/// pointer.
|
| 1285 |
+
CUTLASS_HOST_DEVICE
|
| 1286 |
+
PredicatedTileAccessIterator &operator++() {
|
| 1287 |
+
the_predicates.operator++();
|
| 1288 |
+
++the_predicates.iteration_vector_;
|
| 1289 |
+
if (the_predicates.iteration_vector_ < kAccessesPerVector) {
|
| 1290 |
+
return *this;
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
the_predicates.iteration_vector_ = 0;
|
| 1294 |
+
++the_predicates.iteration_contiguous_;
|
| 1295 |
+
|
| 1296 |
+
if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
| 1297 |
+
pointer_ += params_.inc_contiguous_;
|
| 1298 |
+
return *this;
|
| 1299 |
+
}
|
| 1300 |
+
|
| 1301 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 1302 |
+
// ThreadMap::Iteration::kContiguous)
|
| 1303 |
+
the_predicates.iteration_contiguous_ = 0;
|
| 1304 |
+
++the_predicates.iteration_strided_;
|
| 1305 |
+
|
| 1306 |
+
if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 1307 |
+
pointer_ += params_.inc_next_strided_;
|
| 1308 |
+
return *this;
|
| 1309 |
+
}
|
| 1310 |
+
|
| 1311 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 1312 |
+
// which means we enter the next tile.
|
| 1313 |
+
the_predicates.iteration_strided_ = 0;
|
| 1314 |
+
|
| 1315 |
+
// advance to next tile
|
| 1316 |
+
pointer_ += params_.inc_next_;
|
| 1317 |
+
|
| 1318 |
+
// now return to start tile - if the iterator is subsequently advanced, this
|
| 1319 |
+
// subtraction as well as the subsequent integer addition are both elided by
|
| 1320 |
+
// the compiler.
|
| 1321 |
+
pointer_ -= params_.inc_advance_;
|
| 1322 |
+
|
| 1323 |
+
return *this;
|
| 1324 |
+
}
|
| 1325 |
+
|
| 1326 |
+
/// Advances to the next tile in memory.
|
| 1327 |
+
///
|
| 1328 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1329 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1330 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1331 |
+
/// pointer.
|
| 1332 |
+
CUTLASS_HOST_DEVICE
|
| 1333 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 1334 |
+
PredicatedTileAccessIterator self(*this);
|
| 1335 |
+
operator++();
|
| 1336 |
+
return self;
|
| 1337 |
+
}
|
| 1338 |
+
|
| 1339 |
+
/// Clears the predicate set efficiently
|
| 1340 |
+
CUTLASS_HOST_DEVICE
|
| 1341 |
+
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
|
| 1342 |
+
|
| 1343 |
+
/// Clears the predicate set efficiently
|
| 1344 |
+
CUTLASS_HOST_DEVICE
|
| 1345 |
+
void enable_mask() { the_predicates.enable_mask(); }
|
| 1346 |
+
|
| 1347 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1348 |
+
CUTLASS_HOST_DEVICE
|
| 1349 |
+
void set_mask(Mask const &mask) { the_predicates.set_mask(mask); }
|
| 1350 |
+
|
| 1351 |
+
/// Gets the mask
|
| 1352 |
+
CUTLASS_HOST_DEVICE
|
| 1353 |
+
void get_mask(Mask &mask) { the_predicates.get_mask(mask); }
|
| 1354 |
+
|
| 1355 |
+
/// Returns whether access is valid or not
|
| 1356 |
+
CUTLASS_HOST_DEVICE
|
| 1357 |
+
bool valid() {
|
| 1358 |
+
return the_predicates.valid();
|
| 1359 |
+
}
|
| 1360 |
+
};
|
| 1361 |
+
|
| 1362 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1363 |
+
|
| 1364 |
+
/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data.
|
| 1365 |
+
///
|
| 1366 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1367 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1368 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1369 |
+
/// MaskedTileIteratorConcept
|
| 1370 |
+
///
|
| 1371 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1372 |
+
typename ThreadMap_, typename AccessType_>
|
| 1373 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2ColumnMajor,
|
| 1374 |
+
AdvanceRank, ThreadMap_, AccessType_, false,
|
| 1375 |
+
layout::NoPermute> {
|
| 1376 |
+
public:
|
| 1377 |
+
static_assert(
|
| 1378 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1379 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1380 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1381 |
+
|
| 1382 |
+
using Shape = Shape_;
|
| 1383 |
+
using Element = Element_;
|
| 1384 |
+
using Layout = layout::AffineRank2ColumnMajor;
|
| 1385 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1386 |
+
using ThreadMap = ThreadMap_;
|
| 1387 |
+
using AccessType = AccessType_;
|
| 1388 |
+
|
| 1389 |
+
using Index = typename Layout::Index;
|
| 1390 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1391 |
+
|
| 1392 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1393 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1394 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1395 |
+
|
| 1396 |
+
using Pointer = Element *;
|
| 1397 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1398 |
+
|
| 1399 |
+
// Map to the underlying AffineRankN<2> layout
|
| 1400 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 1401 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 1402 |
+
layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
|
| 1403 |
+
|
| 1404 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 1405 |
+
|
| 1406 |
+
/// Predicate vector stores mask to guard accesses
|
| 1407 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1408 |
+
|
| 1409 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1410 |
+
class Params {
|
| 1411 |
+
private:
|
| 1412 |
+
friend PredicatedTileAccessIterator;
|
| 1413 |
+
|
| 1414 |
+
/// Parameters object
|
| 1415 |
+
typename UnderlyingIterator::Params params_;
|
| 1416 |
+
|
| 1417 |
+
public:
|
| 1418 |
+
|
| 1419 |
+
/// Default constructor
|
| 1420 |
+
Params() = default;
|
| 1421 |
+
|
| 1422 |
+
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1423 |
+
CUTLASS_HOST_DEVICE
|
| 1424 |
+
Params(Layout const &layout)
|
| 1425 |
+
: params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){};
|
| 1426 |
+
};
|
| 1427 |
+
|
| 1428 |
+
private:
|
| 1429 |
+
//
|
| 1430 |
+
// Data members
|
| 1431 |
+
//
|
| 1432 |
+
|
| 1433 |
+
/// Underlying AffineRankN<2> tile iterator
|
| 1434 |
+
UnderlyingIterator iterator_;
|
| 1435 |
+
|
| 1436 |
+
public:
|
| 1437 |
+
|
| 1438 |
+
/// Default constructor
|
| 1439 |
+
PredicatedTileAccessIterator() = default;
|
| 1440 |
+
|
| 1441 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1442 |
+
/// and thread ID
|
| 1443 |
+
CUTLASS_HOST_DEVICE
|
| 1444 |
+
PredicatedTileAccessIterator(
|
| 1445 |
+
///< Precomputed parameters object
|
| 1446 |
+
Params const ¶ms,
|
| 1447 |
+
///< Pointer to start of tensor
|
| 1448 |
+
Pointer pointer,
|
| 1449 |
+
///< Extent of tensor
|
| 1450 |
+
TensorCoord extent,
|
| 1451 |
+
///< ID of each participating thread
|
| 1452 |
+
int thread_id,
|
| 1453 |
+
///< Initial offset of threadblock
|
| 1454 |
+
TensorCoord const &threadblock_offset,
|
| 1455 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1456 |
+
)
|
| 1457 |
+
: iterator_(params.params_, pointer,
|
| 1458 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 1459 |
+
thread_id,
|
| 1460 |
+
layout::PitchLinearCoord(threadblock_offset.row(),
|
| 1461 |
+
threadblock_offset.column())) {}
|
| 1462 |
+
|
| 1463 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 1464 |
+
CUTLASS_HOST_DEVICE
|
| 1465 |
+
PredicatedTileAccessIterator(
|
| 1466 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1467 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1468 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1469 |
+
int thread_id ///< ID of each participating thread
|
| 1470 |
+
)
|
| 1471 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1472 |
+
make_Coord(0, 0)) {}
|
| 1473 |
+
|
| 1474 |
+
/// Overrides the internal iteration index
|
| 1475 |
+
CUTLASS_HOST_DEVICE
|
| 1476 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1477 |
+
|
| 1478 |
+
/// Adds a pointer offset in units of Element
|
| 1479 |
+
CUTLASS_HOST_DEVICE
|
| 1480 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1481 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1482 |
+
}
|
| 1483 |
+
|
| 1484 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1485 |
+
/// tiles
|
| 1486 |
+
CUTLASS_HOST_DEVICE
|
| 1487 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1488 |
+
iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column()));
|
| 1489 |
+
}
|
| 1490 |
+
|
| 1491 |
+
/// Returns a pointer
|
| 1492 |
+
CUTLASS_HOST_DEVICE
|
| 1493 |
+
AccessType *get() const {
|
| 1494 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
/// Advances to the next tile in memory.
|
| 1498 |
+
///
|
| 1499 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1500 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1501 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1502 |
+
/// pointer.
|
| 1503 |
+
CUTLASS_HOST_DEVICE
|
| 1504 |
+
PredicatedTileAccessIterator &operator++() {
|
| 1505 |
+
++iterator_;
|
| 1506 |
+
return *this;
|
| 1507 |
+
}
|
| 1508 |
+
|
| 1509 |
+
/// Advances to the next tile in memory.
|
| 1510 |
+
///
|
| 1511 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1512 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1513 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1514 |
+
/// pointer.
|
| 1515 |
+
CUTLASS_HOST_DEVICE
|
| 1516 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 1517 |
+
PredicatedTileAccessIterator self(*this);
|
| 1518 |
+
operator++();
|
| 1519 |
+
return self;
|
| 1520 |
+
}
|
| 1521 |
+
|
| 1522 |
+
/// Clears the predicate set efficiently
|
| 1523 |
+
CUTLASS_HOST_DEVICE
|
| 1524 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1525 |
+
|
| 1526 |
+
/// Clears the predicate set efficiently
|
| 1527 |
+
CUTLASS_HOST_DEVICE
|
| 1528 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1529 |
+
|
| 1530 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1531 |
+
CUTLASS_HOST_DEVICE
|
| 1532 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1533 |
+
|
| 1534 |
+
/// Gets the mask
|
| 1535 |
+
CUTLASS_HOST_DEVICE
|
| 1536 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1537 |
+
|
| 1538 |
+
/// Returns whether access is valid or not
|
| 1539 |
+
CUTLASS_HOST_DEVICE
|
| 1540 |
+
bool valid() {
|
| 1541 |
+
return iterator_.valid();
|
| 1542 |
+
}
|
| 1543 |
+
};
|
| 1544 |
+
|
| 1545 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1546 |
+
|
| 1547 |
+
/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data.
|
| 1548 |
+
///
|
| 1549 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1550 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1551 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1552 |
+
/// MaskedTileIteratorConcept
|
| 1553 |
+
///
|
| 1554 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1555 |
+
typename ThreadMap_, typename AccessType_>
|
| 1556 |
+
class PredicatedTileAccessIterator<Shape_, Element_, layout::AffineRank2RowMajor,
|
| 1557 |
+
AdvanceRank, ThreadMap_, AccessType_, false,
|
| 1558 |
+
layout::NoPermute> {
|
| 1559 |
+
public:
|
| 1560 |
+
static_assert(
|
| 1561 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1562 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1563 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1564 |
+
|
| 1565 |
+
using Shape = Shape_;
|
| 1566 |
+
using Element = Element_;
|
| 1567 |
+
using Layout = layout::AffineRank2RowMajor;
|
| 1568 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1569 |
+
using ThreadMap = ThreadMap_;
|
| 1570 |
+
using AccessType = AccessType_;
|
| 1571 |
+
|
| 1572 |
+
using Index = typename Layout::Index;
|
| 1573 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1574 |
+
|
| 1575 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1576 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1577 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1578 |
+
|
| 1579 |
+
using Pointer = Element *;
|
| 1580 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1581 |
+
|
| 1582 |
+
// Map to the underlying AffineRankN<2> layout
|
| 1583 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 1584 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 1585 |
+
layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
|
| 1586 |
+
|
| 1587 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 1588 |
+
|
| 1589 |
+
/// Predicate vector stores mask to guard accesses
|
| 1590 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1591 |
+
|
| 1592 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1593 |
+
class Params {
|
| 1594 |
+
private:
|
| 1595 |
+
friend PredicatedTileAccessIterator;
|
| 1596 |
+
|
| 1597 |
+
/// Parameters object
|
| 1598 |
+
typename UnderlyingIterator::Params params_;
|
| 1599 |
+
|
| 1600 |
+
public:
|
| 1601 |
+
|
| 1602 |
+
/// Default constructor
|
| 1603 |
+
Params() = default;
|
| 1604 |
+
|
| 1605 |
+
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1606 |
+
CUTLASS_HOST_DEVICE
|
| 1607 |
+
Params(Layout const &layout)
|
| 1608 |
+
: params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){};
|
| 1609 |
+
};
|
| 1610 |
+
|
| 1611 |
+
private:
|
| 1612 |
+
//
|
| 1613 |
+
// Data members
|
| 1614 |
+
//
|
| 1615 |
+
|
| 1616 |
+
/// Underlying AffineRankN<2> tile iterator
|
| 1617 |
+
UnderlyingIterator iterator_;
|
| 1618 |
+
|
| 1619 |
+
public:
|
| 1620 |
+
|
| 1621 |
+
/// Default constructor
|
| 1622 |
+
PredicatedTileAccessIterator() = default;
|
| 1623 |
+
|
| 1624 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1625 |
+
/// and thread ID
|
| 1626 |
+
CUTLASS_HOST_DEVICE
|
| 1627 |
+
PredicatedTileAccessIterator(
|
| 1628 |
+
///< Precomputed parameters object
|
| 1629 |
+
Params const ¶ms,
|
| 1630 |
+
///< Pointer to start of tensor
|
| 1631 |
+
Pointer pointer,
|
| 1632 |
+
///< Extent of tensor
|
| 1633 |
+
TensorCoord extent,
|
| 1634 |
+
///< ID of each participating thread
|
| 1635 |
+
int thread_id,
|
| 1636 |
+
///< Initial offset of threadblock
|
| 1637 |
+
TensorCoord const &threadblock_offset,
|
| 1638 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1639 |
+
)
|
| 1640 |
+
: iterator_(params.params_, pointer,
|
| 1641 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 1642 |
+
thread_id,
|
| 1643 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 1644 |
+
threadblock_offset.row())) {}
|
| 1645 |
+
|
| 1646 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 1647 |
+
CUTLASS_HOST_DEVICE
|
| 1648 |
+
PredicatedTileAccessIterator(
|
| 1649 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1650 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1651 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1652 |
+
int thread_id ///< ID of each participating thread
|
| 1653 |
+
)
|
| 1654 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1655 |
+
make_Coord(0, 0)) {}
|
| 1656 |
+
|
| 1657 |
+
/// Overrides the internal iteration index
|
| 1658 |
+
CUTLASS_HOST_DEVICE
|
| 1659 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1660 |
+
|
| 1661 |
+
/// Adds a pointer offset in units of Element
|
| 1662 |
+
CUTLASS_HOST_DEVICE
|
| 1663 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1664 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1665 |
+
}
|
| 1666 |
+
|
| 1667 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1668 |
+
/// tiles
|
| 1669 |
+
CUTLASS_HOST_DEVICE
|
| 1670 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1671 |
+
iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row()));
|
| 1672 |
+
}
|
| 1673 |
+
|
| 1674 |
+
/// Returns a pointer
|
| 1675 |
+
CUTLASS_HOST_DEVICE
|
| 1676 |
+
AccessType *get() const {
|
| 1677 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1678 |
+
}
|
| 1679 |
+
|
| 1680 |
+
/// Advances to the next tile in memory.
|
| 1681 |
+
///
|
| 1682 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1683 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1684 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1685 |
+
/// pointer.
|
| 1686 |
+
CUTLASS_HOST_DEVICE
|
| 1687 |
+
PredicatedTileAccessIterator &operator++() {
|
| 1688 |
+
++iterator_;
|
| 1689 |
+
return *this;
|
| 1690 |
+
}
|
| 1691 |
+
|
| 1692 |
+
/// Advances to the next tile in memory.
|
| 1693 |
+
///
|
| 1694 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1695 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1696 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1697 |
+
/// pointer.
|
| 1698 |
+
CUTLASS_HOST_DEVICE
|
| 1699 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 1700 |
+
PredicatedTileAccessIterator self(*this);
|
| 1701 |
+
operator++();
|
| 1702 |
+
return self;
|
| 1703 |
+
}
|
| 1704 |
+
|
| 1705 |
+
/// Clears the predicate set efficiently
|
| 1706 |
+
CUTLASS_HOST_DEVICE
|
| 1707 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1708 |
+
|
| 1709 |
+
/// Clears the predicate set efficiently
|
| 1710 |
+
CUTLASS_HOST_DEVICE
|
| 1711 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1712 |
+
|
| 1713 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1714 |
+
CUTLASS_HOST_DEVICE
|
| 1715 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1716 |
+
|
| 1717 |
+
/// Gets the mask
|
| 1718 |
+
CUTLASS_HOST_DEVICE
|
| 1719 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1720 |
+
|
| 1721 |
+
/// Returns whether access is valid or not
|
| 1722 |
+
CUTLASS_HOST_DEVICE
|
| 1723 |
+
bool valid() {
|
| 1724 |
+
return iterator_.valid();
|
| 1725 |
+
}
|
| 1726 |
+
};
|
| 1727 |
+
|
| 1728 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1729 |
+
|
| 1730 |
+
/// Specialization of PredicatedTileAccessIterator for column-major interleaved data.
|
| 1731 |
+
/// It is mapped to the congruous layout.
|
| 1732 |
+
///
|
| 1733 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1734 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1735 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1736 |
+
/// MaskedTileIteratorConcept
|
| 1737 |
+
///
|
| 1738 |
+
|
| 1739 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1740 |
+
typename ThreadMap_, typename AccessType_, int InterleavedK>
|
| 1741 |
+
class PredicatedTileAccessIterator<Shape_, Element_,
|
| 1742 |
+
layout::ColumnMajorInterleaved<InterleavedK>,
|
| 1743 |
+
AdvanceRank, ThreadMap_, AccessType_, false,
|
| 1744 |
+
layout::NoPermute> {
|
| 1745 |
+
public:
|
| 1746 |
+
static_assert(
|
| 1747 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1748 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1749 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1750 |
+
|
| 1751 |
+
using Shape = Shape_;
|
| 1752 |
+
using Element = Element_;
|
| 1753 |
+
static int const kInterleavedK = InterleavedK;
|
| 1754 |
+
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
|
| 1755 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1756 |
+
using ThreadMap = ThreadMap_;
|
| 1757 |
+
using AccessType = AccessType_;
|
| 1758 |
+
|
| 1759 |
+
using Index = typename Layout::Index;
|
| 1760 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1761 |
+
|
| 1762 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1763 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1764 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1765 |
+
|
| 1766 |
+
using Pointer = Element *;
|
| 1767 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1768 |
+
|
| 1769 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 1770 |
+
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
|
| 1771 |
+
Shape::kColumn / kInterleavedK>,
|
| 1772 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
|
| 1773 |
+
AccessType>;
|
| 1774 |
+
|
| 1775 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 1776 |
+
|
| 1777 |
+
/// Predicate vector stores mask to guard accesses
|
| 1778 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1779 |
+
|
| 1780 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1781 |
+
class Params {
|
| 1782 |
+
private:
|
| 1783 |
+
friend PredicatedTileAccessIterator;
|
| 1784 |
+
|
| 1785 |
+
/// Parameters object
|
| 1786 |
+
typename UnderlyingIterator::Params params_;
|
| 1787 |
+
|
| 1788 |
+
public:
|
| 1789 |
+
|
| 1790 |
+
/// Default constructor
|
| 1791 |
+
Params() = default;
|
| 1792 |
+
|
| 1793 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1794 |
+
CUTLASS_HOST_DEVICE
|
| 1795 |
+
Params(Layout const &layout)
|
| 1796 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1797 |
+
|
| 1798 |
+
CUTLASS_HOST_DEVICE
|
| 1799 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 1800 |
+
: params_(base) {}
|
| 1801 |
+
};
|
| 1802 |
+
|
| 1803 |
+
private:
|
| 1804 |
+
//
|
| 1805 |
+
// Data members
|
| 1806 |
+
//
|
| 1807 |
+
|
| 1808 |
+
/// Underlying pitch-linear tile iterator
|
| 1809 |
+
UnderlyingIterator iterator_;
|
| 1810 |
+
|
| 1811 |
+
public:
|
| 1812 |
+
|
| 1813 |
+
/// Default constructor
|
| 1814 |
+
PredicatedTileAccessIterator() = default;
|
| 1815 |
+
|
| 1816 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1817 |
+
/// and thread ID
|
| 1818 |
+
CUTLASS_HOST_DEVICE
|
| 1819 |
+
PredicatedTileAccessIterator(
|
| 1820 |
+
/// Precomputed parameters object
|
| 1821 |
+
Params const ¶ms,
|
| 1822 |
+
/// Pointer to start of tensor
|
| 1823 |
+
Pointer pointer,
|
| 1824 |
+
/// Extent of tensor
|
| 1825 |
+
TensorCoord extent,
|
| 1826 |
+
/// ID of each participating thread
|
| 1827 |
+
int thread_id,
|
| 1828 |
+
/// Initial offset of threadblock
|
| 1829 |
+
TensorCoord const &threadblock_offset,
|
| 1830 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1831 |
+
)
|
| 1832 |
+
: iterator_(params.params_, pointer,
|
| 1833 |
+
layout::PitchLinearCoord(extent.row() * kInterleavedK,
|
| 1834 |
+
extent.column() / kInterleavedK),
|
| 1835 |
+
thread_id,
|
| 1836 |
+
layout::PitchLinearCoord(
|
| 1837 |
+
threadblock_offset.row() * kInterleavedK,
|
| 1838 |
+
threadblock_offset.column() / kInterleavedK)) {}
|
| 1839 |
+
|
| 1840 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 1841 |
+
CUTLASS_HOST_DEVICE
|
| 1842 |
+
PredicatedTileAccessIterator(
|
| 1843 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1844 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1845 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1846 |
+
int thread_id ///< ID of each participating thread
|
| 1847 |
+
)
|
| 1848 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 1849 |
+
make_Coord(0, 0)) {}
|
| 1850 |
+
|
| 1851 |
+
/// Overrides the internal iteration index
|
| 1852 |
+
CUTLASS_HOST_DEVICE
|
| 1853 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1854 |
+
|
| 1855 |
+
/// Adds a pointer offset in units of Element
|
| 1856 |
+
CUTLASS_HOST_DEVICE
|
| 1857 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1858 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1859 |
+
}
|
| 1860 |
+
|
| 1861 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 1862 |
+
/// tiles
|
| 1863 |
+
CUTLASS_HOST_DEVICE
|
| 1864 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 1865 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 1866 |
+
}
|
| 1867 |
+
|
| 1868 |
+
/// Returns a pointer
|
| 1869 |
+
CUTLASS_HOST_DEVICE
|
| 1870 |
+
AccessType *get() const {
|
| 1871 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1872 |
+
}
|
| 1873 |
+
|
| 1874 |
+
/// Advances to the next tile in memory.
|
| 1875 |
+
///
|
| 1876 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1877 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1878 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1879 |
+
/// pointer.
|
| 1880 |
+
CUTLASS_HOST_DEVICE
|
| 1881 |
+
PredicatedTileAccessIterator &operator++() {
|
| 1882 |
+
++iterator_;
|
| 1883 |
+
return *this;
|
| 1884 |
+
}
|
| 1885 |
+
|
| 1886 |
+
/// Advances to the next tile in memory.
|
| 1887 |
+
///
|
| 1888 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1889 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1890 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1891 |
+
/// pointer.
|
| 1892 |
+
CUTLASS_HOST_DEVICE
|
| 1893 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 1894 |
+
PredicatedTileAccessIterator self(*this);
|
| 1895 |
+
operator++();
|
| 1896 |
+
return self;
|
| 1897 |
+
}
|
| 1898 |
+
|
| 1899 |
+
/// Clears the predicate set efficiently
|
| 1900 |
+
CUTLASS_HOST_DEVICE
|
| 1901 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1902 |
+
|
| 1903 |
+
/// Clears the predicate set efficiently
|
| 1904 |
+
CUTLASS_HOST_DEVICE
|
| 1905 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1906 |
+
|
| 1907 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1908 |
+
CUTLASS_HOST_DEVICE
|
| 1909 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1910 |
+
|
| 1911 |
+
/// Gets the mask
|
| 1912 |
+
CUTLASS_HOST_DEVICE
|
| 1913 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1914 |
+
|
| 1915 |
+
/// Returns whether access is valid or not
|
| 1916 |
+
CUTLASS_HOST_DEVICE
|
| 1917 |
+
bool valid() { return iterator_.valid(); }
|
| 1918 |
+
};
|
| 1919 |
+
|
| 1920 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1921 |
+
|
| 1922 |
+
/// Specialization of PredicatedTileAccessIterator for row-major interleaved data.
|
| 1923 |
+
// It is mapped to the congruous layout.
|
| 1924 |
+
///
|
| 1925 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1926 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1927 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1928 |
+
/// MaskedTileIteratorConcept
|
| 1929 |
+
///
|
| 1930 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1931 |
+
typename ThreadMap_, typename AccessType_, int InterleavedK>
|
| 1932 |
+
class PredicatedTileAccessIterator<Shape_, Element_,
|
| 1933 |
+
layout::RowMajorInterleaved<InterleavedK>,
|
| 1934 |
+
AdvanceRank, ThreadMap_, AccessType_, false,
|
| 1935 |
+
layout::NoPermute> {
|
| 1936 |
+
public:
|
| 1937 |
+
static_assert(
|
| 1938 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1939 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1940 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1941 |
+
|
| 1942 |
+
using Shape = Shape_;
|
| 1943 |
+
using Element = Element_;
|
| 1944 |
+
static int const kInterleavedK = InterleavedK;
|
| 1945 |
+
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
|
| 1946 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1947 |
+
using ThreadMap = ThreadMap_;
|
| 1948 |
+
using AccessType = AccessType_;
|
| 1949 |
+
|
| 1950 |
+
using Index = typename Layout::Index;
|
| 1951 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1952 |
+
|
| 1953 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1954 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1955 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1956 |
+
|
| 1957 |
+
using Pointer = Element *;
|
| 1958 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1959 |
+
|
| 1960 |
+
using UnderlyingIterator = PredicatedTileAccessIterator<
|
| 1961 |
+
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
|
| 1962 |
+
Shape::kRow / kInterleavedK>,
|
| 1963 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
|
| 1964 |
+
AccessType>;
|
| 1965 |
+
|
| 1966 |
+
|
| 1967 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 1968 |
+
|
| 1969 |
+
/// Predicate vector stores mask to guard accesses
|
| 1970 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1971 |
+
|
| 1972 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1973 |
+
class Params {
|
| 1974 |
+
private:
|
| 1975 |
+
friend PredicatedTileAccessIterator;
|
| 1976 |
+
|
| 1977 |
+
/// Parameters object
|
| 1978 |
+
typename UnderlyingIterator::Params params_;
|
| 1979 |
+
|
| 1980 |
+
public:
|
| 1981 |
+
|
| 1982 |
+
/// Default constructor
|
| 1983 |
+
Params() = default;
|
| 1984 |
+
|
| 1985 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1986 |
+
CUTLASS_HOST_DEVICE
|
| 1987 |
+
Params(Layout const &layout)
|
| 1988 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1989 |
+
|
| 1990 |
+
CUTLASS_HOST_DEVICE
|
| 1991 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 1992 |
+
: params_(base) {}
|
| 1993 |
+
};
|
| 1994 |
+
|
| 1995 |
+
private:
|
| 1996 |
+
//
|
| 1997 |
+
// Data members
|
| 1998 |
+
//
|
| 1999 |
+
|
| 2000 |
+
/// Underlying pitch-linear tile iterator
|
| 2001 |
+
UnderlyingIterator iterator_;
|
| 2002 |
+
|
| 2003 |
+
public:
|
| 2004 |
+
|
| 2005 |
+
/// Default constructor
|
| 2006 |
+
PredicatedTileAccessIterator() = default;
|
| 2007 |
+
|
| 2008 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 2009 |
+
/// and thread ID
|
| 2010 |
+
CUTLASS_HOST_DEVICE
|
| 2011 |
+
PredicatedTileAccessIterator(
|
| 2012 |
+
/// Precomputed parameters object
|
| 2013 |
+
Params const ¶ms,
|
| 2014 |
+
/// Pointer to start of tensor
|
| 2015 |
+
Pointer pointer,
|
| 2016 |
+
/// Extent of tensor
|
| 2017 |
+
TensorCoord extent,
|
| 2018 |
+
/// ID of each participating thread
|
| 2019 |
+
int thread_id,
|
| 2020 |
+
/// Initial offset of threadblock
|
| 2021 |
+
TensorCoord const &threadblock_offset,
|
| 2022 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 2023 |
+
)
|
| 2024 |
+
: iterator_(params.params_, pointer,
|
| 2025 |
+
layout::PitchLinearCoord(extent.column() * kInterleavedK,
|
| 2026 |
+
extent.row() / kInterleavedK),
|
| 2027 |
+
thread_id,
|
| 2028 |
+
layout::PitchLinearCoord(
|
| 2029 |
+
threadblock_offset.column() * kInterleavedK,
|
| 2030 |
+
threadblock_offset.row() / kInterleavedK)) {}
|
| 2031 |
+
|
| 2032 |
+
/// Construct a PredicatedTileAccessIterator with zero threadblock offset
|
| 2033 |
+
CUTLASS_HOST_DEVICE
|
| 2034 |
+
PredicatedTileAccessIterator(
|
| 2035 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 2036 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 2037 |
+
TensorCoord extent, ///< Extent of tensor
|
| 2038 |
+
int thread_id ///< ID of each participating thread
|
| 2039 |
+
)
|
| 2040 |
+
: PredicatedTileAccessIterator(params, pointer, extent, thread_id,
|
| 2041 |
+
make_Coord(0, 0)) {}
|
| 2042 |
+
|
| 2043 |
+
/// Overrides the internal iteration index
|
| 2044 |
+
CUTLASS_HOST_DEVICE
|
| 2045 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 2046 |
+
|
| 2047 |
+
/// Adds a pointer offset in units of Element
|
| 2048 |
+
CUTLASS_HOST_DEVICE
|
| 2049 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 2050 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 2051 |
+
}
|
| 2052 |
+
|
| 2053 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 2054 |
+
/// tiles
|
| 2055 |
+
CUTLASS_HOST_DEVICE
|
| 2056 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 2057 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 2058 |
+
}
|
| 2059 |
+
|
| 2060 |
+
/// Returns a pointer
|
| 2061 |
+
CUTLASS_HOST_DEVICE
|
| 2062 |
+
AccessType *get() const {
|
| 2063 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 2064 |
+
}
|
| 2065 |
+
|
| 2066 |
+
/// Advances to the next tile in memory.
|
| 2067 |
+
///
|
| 2068 |
+
/// The first time this method is called, predicates are updated, and the
|
| 2069 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 2070 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 2071 |
+
/// pointer.
|
| 2072 |
+
CUTLASS_HOST_DEVICE
|
| 2073 |
+
PredicatedTileAccessIterator &operator++() {
|
| 2074 |
+
++iterator_;
|
| 2075 |
+
return *this;
|
| 2076 |
+
}
|
| 2077 |
+
|
| 2078 |
+
/// Advances to the next tile in memory.
|
| 2079 |
+
///
|
| 2080 |
+
/// The first time this method is called, predicates are updated, and the
|
| 2081 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 2082 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 2083 |
+
/// pointer.
|
| 2084 |
+
CUTLASS_HOST_DEVICE
|
| 2085 |
+
PredicatedTileAccessIterator operator++(int) {
|
| 2086 |
+
PredicatedTileAccessIterator self(*this);
|
| 2087 |
+
operator++();
|
| 2088 |
+
return self;
|
| 2089 |
+
}
|
| 2090 |
+
|
| 2091 |
+
/// Clears the predicate set efficiently
|
| 2092 |
+
CUTLASS_HOST_DEVICE
|
| 2093 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 2094 |
+
|
| 2095 |
+
/// Clears the predicate set efficiently
|
| 2096 |
+
CUTLASS_HOST_DEVICE
|
| 2097 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 2098 |
+
|
| 2099 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 2100 |
+
CUTLASS_HOST_DEVICE
|
| 2101 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 2102 |
+
|
| 2103 |
+
/// Gets the mask
|
| 2104 |
+
CUTLASS_HOST_DEVICE
|
| 2105 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 2106 |
+
|
| 2107 |
+
/// Returns whether access is valid or not
|
| 2108 |
+
CUTLASS_HOST_DEVICE
|
| 2109 |
+
bool valid() { return iterator_.valid(); }
|
| 2110 |
+
};
|
| 2111 |
+
|
| 2112 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 2113 |
+
|
| 2114 |
+
} // namespace threadblock
|
| 2115 |
+
} // namespace transform
|
| 2116 |
+
} // namespace cutlass
|
| 2117 |
+
|
| 2118 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates calculating the address and predicates to the load of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
|
| 35 |
+
This iterator uses masks to guard out-of-bounds accesses and visits the last
|
| 36 |
+
"residue" tile first, with the objective of minimizing predicate mask updates
|
| 37 |
+
during steady-state operation.
|
| 38 |
+
|
| 39 |
+
A precomputed "Params" object minimizes the amount of state that must be
|
| 40 |
+
stored in registers, and integer addition is used to advance the pointer
|
| 41 |
+
through memory.
|
| 42 |
+
*/
|
| 43 |
+
|
| 44 |
+
#pragma once
|
| 45 |
+
|
| 46 |
+
#include "cutlass/array.h"
|
| 47 |
+
#include "cutlass/coord.h"
|
| 48 |
+
#include "cutlass/cutlass.h"
|
| 49 |
+
#include "cutlass/layout/matrix.h"
|
| 50 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 51 |
+
#include "cutlass/matrix_shape.h"
|
| 52 |
+
#include "cutlass/predicate_vector.h"
|
| 53 |
+
#include "cutlass/tensor_ref.h"
|
| 54 |
+
#include "cutlass/tensor_view.h"
|
| 55 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
|
| 56 |
+
|
| 57 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 60 |
+
|
| 61 |
+
namespace cutlass {
|
| 62 |
+
namespace transform {
|
| 63 |
+
namespace threadblock {
|
| 64 |
+
|
| 65 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
/// PredicatedTileAccessIterator2dThreadTile
|
| 68 |
+
///
|
| 69 |
+
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
|
| 70 |
+
typename ThreadMap, typename AccessType>
|
| 71 |
+
class PredicatedTileAccessIterator2dThreadTile;
|
| 72 |
+
|
| 73 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
|
| 75 |
+
/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
|
| 76 |
+
///
|
| 77 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 78 |
+
typename ThreadMap_, typename AccessType_>
|
| 79 |
+
class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
|
| 80 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 81 |
+
public:
|
| 82 |
+
static_assert(
|
| 83 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 84 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 85 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 86 |
+
|
| 87 |
+
using Shape = Shape_;
|
| 88 |
+
using Element = Element_;
|
| 89 |
+
using Layout = layout::PitchLinear;
|
| 90 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 91 |
+
using ThreadMap = ThreadMap_;
|
| 92 |
+
using AccessType = AccessType_;
|
| 93 |
+
|
| 94 |
+
using Index = typename Layout::Index;
|
| 95 |
+
using LongIndex = typename Layout::LongIndex;
|
| 96 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 97 |
+
|
| 98 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 99 |
+
using TensorView = TensorView<Element, Layout>;
|
| 100 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 101 |
+
|
| 102 |
+
using Pointer = Element *;
|
| 103 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 104 |
+
|
| 105 |
+
static int const kPredicatesPerByte = 4;
|
| 106 |
+
static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
|
| 107 |
+
|
| 108 |
+
/// Number of 32b words containing predicates
|
| 109 |
+
static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
| 110 |
+
static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
|
| 111 |
+
|
| 112 |
+
static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
|
| 113 |
+
|
| 114 |
+
static_assert(kPredicateWordCount <= 4, "Too many predicates.");
|
| 115 |
+
|
| 116 |
+
/// Predicate vector stores mask to guard accesses
|
| 117 |
+
using Mask = Array<uint32_t, kPredicateWordCount>;
|
| 118 |
+
|
| 119 |
+
/// Uses a non-template class
|
| 120 |
+
struct Params : PredicatedTileAccessIteratorParams {
|
| 121 |
+
|
| 122 |
+
public:
|
| 123 |
+
friend PredicatedTileAccessIterator2dThreadTile;
|
| 124 |
+
|
| 125 |
+
using Base = PredicatedTileAccessIteratorParams;
|
| 126 |
+
|
| 127 |
+
// Default ctor
|
| 128 |
+
CUTLASS_HOST_DEVICE
|
| 129 |
+
Params() { }
|
| 130 |
+
|
| 131 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 132 |
+
CUTLASS_HOST_DEVICE
|
| 133 |
+
Params(Layout const &layout) :
|
| 134 |
+
Base(layout.stride(0),
|
| 135 |
+
MakePredicatedTileAccessIteratorDesc<Shape, Element, Layout, kAdvanceRank, ThreadMap>()()
|
| 136 |
+
) { }
|
| 137 |
+
|
| 138 |
+
CUTLASS_HOST_DEVICE
|
| 139 |
+
Params(Base const &base) :
|
| 140 |
+
Base(base) { }
|
| 141 |
+
};
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
private:
|
| 145 |
+
/// Internal pointer type permits fast address arithmetic
|
| 146 |
+
using BytePointer = char *;
|
| 147 |
+
|
| 148 |
+
private:
|
| 149 |
+
//
|
| 150 |
+
// Data members
|
| 151 |
+
//
|
| 152 |
+
|
| 153 |
+
/// Parameters object with precomputed internal state
|
| 154 |
+
Params const ¶ms_;
|
| 155 |
+
|
| 156 |
+
/// Internal pointer to first access of tile
|
| 157 |
+
BytePointer pointer_;
|
| 158 |
+
|
| 159 |
+
/// Guard predicates
|
| 160 |
+
uint32_t predicates_[kPredicateWordCount];
|
| 161 |
+
|
| 162 |
+
/// Size of tensor
|
| 163 |
+
TensorCoord extent_;
|
| 164 |
+
|
| 165 |
+
/// Initial offset for each thread
|
| 166 |
+
TensorCoord thread_offset_;
|
| 167 |
+
|
| 168 |
+
/// Index of residue tile
|
| 169 |
+
int residue_tile_idx_;
|
| 170 |
+
|
| 171 |
+
/// Used for out-of-order visitation
|
| 172 |
+
bool is_residue_tile_;
|
| 173 |
+
|
| 174 |
+
/// Iteration in the contiguous dimension
|
| 175 |
+
int iteration_contiguous_;
|
| 176 |
+
|
| 177 |
+
/// Iteration in the strided dimension
|
| 178 |
+
int iteration_strided_;
|
| 179 |
+
|
| 180 |
+
/// Tracks iterations within the thread loop
|
| 181 |
+
int iteration_thread_;
|
| 182 |
+
|
| 183 |
+
private:
|
| 184 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 185 |
+
CUTLASS_HOST_DEVICE
|
| 186 |
+
void compute_predicates_(
|
| 187 |
+
/// optionally, simplify predicate calculation during 'steady state' phase
|
| 188 |
+
bool is_steady_state = false) {
|
| 189 |
+
|
| 190 |
+
CUTLASS_PRAGMA_UNROLL
|
| 191 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 192 |
+
predicates_[i] = 0u;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
CUTLASS_PRAGMA_UNROLL
|
| 196 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 197 |
+
CUTLASS_PRAGMA_UNROLL
|
| 198 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 199 |
+
CUTLASS_PRAGMA_UNROLL
|
| 200 |
+
for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) {
|
| 201 |
+
|
| 202 |
+
TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous,
|
| 203 |
+
ts + s * ThreadMap::Delta::kStrided);
|
| 204 |
+
|
| 205 |
+
TensorCoord coord = thread_offset_ + iteration_coord;
|
| 206 |
+
|
| 207 |
+
bool guard;
|
| 208 |
+
|
| 209 |
+
if (is_steady_state) {
|
| 210 |
+
if (kAdvanceRank == 0) {
|
| 211 |
+
guard = (coord.strided() < extent_.strided());
|
| 212 |
+
} else {
|
| 213 |
+
guard = (coord.contiguous() < extent_.contiguous());
|
| 214 |
+
}
|
| 215 |
+
} else {
|
| 216 |
+
guard = (coord.strided() < extent_.strided() &&
|
| 217 |
+
coord.contiguous() < extent_.contiguous());
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
|
| 221 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 222 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 223 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 224 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 225 |
+
|
| 226 |
+
predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
|
| 227 |
+
|
| 228 |
+
}
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
public:
|
| 235 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 236 |
+
/// and thread ID
|
| 237 |
+
CUTLASS_HOST_DEVICE
|
| 238 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 239 |
+
/// Precomputed parameters object
|
| 240 |
+
Params const ¶ms,
|
| 241 |
+
/// Pointer to start of tensor
|
| 242 |
+
Pointer pointer,
|
| 243 |
+
/// Extent of tensor
|
| 244 |
+
TensorCoord extent,
|
| 245 |
+
/// ID of each participating thread
|
| 246 |
+
int thread_id,
|
| 247 |
+
/// Initial offset of threadblock
|
| 248 |
+
TensorCoord const &threadblock_offset)
|
| 249 |
+
: params_(params),
|
| 250 |
+
pointer_(reinterpret_cast<BytePointer>(
|
| 251 |
+
const_cast<NonConstPointer>(pointer))),
|
| 252 |
+
extent_(extent),
|
| 253 |
+
is_residue_tile_(true) {
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
TensorCoord residue_offset;
|
| 257 |
+
if (kAdvanceRank) {
|
| 258 |
+
residue_tile_idx_ =
|
| 259 |
+
(extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
|
| 260 |
+
Shape::kStrided;
|
| 261 |
+
residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided);
|
| 262 |
+
} else {
|
| 263 |
+
residue_tile_idx_ =
|
| 264 |
+
(extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) /
|
| 265 |
+
Shape::kContiguous;
|
| 266 |
+
residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
// Per-thread offset in logical coordinates of tensor
|
| 270 |
+
thread_offset_ = threadblock_offset + residue_offset +
|
| 271 |
+
ThreadMap::initial_offset(thread_id);
|
| 272 |
+
|
| 273 |
+
// update internal pointers
|
| 274 |
+
Layout layout(params_.stride_);
|
| 275 |
+
add_pointer_offset(layout(thread_offset_));
|
| 276 |
+
|
| 277 |
+
compute_predicates_(false);
|
| 278 |
+
|
| 279 |
+
set_iteration_index(0);
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
|
| 283 |
+
CUTLASS_HOST_DEVICE
|
| 284 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 285 |
+
/// Precomputed parameters object
|
| 286 |
+
Params const ¶ms,
|
| 287 |
+
/// Pointer to start of tensor
|
| 288 |
+
Pointer pointer,
|
| 289 |
+
/// Extent of tensor
|
| 290 |
+
TensorCoord extent,
|
| 291 |
+
///< ID of each participating thread
|
| 292 |
+
int thread_id)
|
| 293 |
+
: PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
|
| 294 |
+
make_Coord(0, 0)) {}
|
| 295 |
+
|
| 296 |
+
/// Overrides the internal iteration index
|
| 297 |
+
CUTLASS_HOST_DEVICE
|
| 298 |
+
void set_iteration_index(int index) {
|
| 299 |
+
|
| 300 |
+
int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
|
| 301 |
+
iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided);
|
| 302 |
+
|
| 303 |
+
iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided;
|
| 304 |
+
iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided;
|
| 305 |
+
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
/// Adds a pointer offset in units of Element
|
| 309 |
+
CUTLASS_HOST_DEVICE
|
| 310 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 311 |
+
pointer_ += int(sizeof(Element)) * pointer_offset;
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
| 315 |
+
CUTLASS_DEVICE
|
| 316 |
+
void add_tile_offset(
|
| 317 |
+
TensorCoord const &tile_offset) {
|
| 318 |
+
if (is_residue_tile_) {
|
| 319 |
+
TensorCoord residue_offset;
|
| 320 |
+
if (kAdvanceRank) {
|
| 321 |
+
residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided);
|
| 322 |
+
} else {
|
| 323 |
+
residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0);
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
thread_offset_ -= residue_offset;
|
| 327 |
+
|
| 328 |
+
Layout layout(params_.stride_);
|
| 329 |
+
add_pointer_offset(-layout(residue_offset));
|
| 330 |
+
|
| 331 |
+
compute_predicates_(true);
|
| 332 |
+
|
| 333 |
+
if (kAdvanceRank) {
|
| 334 |
+
pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
|
| 335 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 336 |
+
} else {
|
| 337 |
+
pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
|
| 338 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 339 |
+
}
|
| 340 |
+
} else {
|
| 341 |
+
if (kAdvanceRank) {
|
| 342 |
+
pointer_ += params_.inc_advance_ * tile_offset.strided();
|
| 343 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 344 |
+
} else {
|
| 345 |
+
pointer_ += params_.inc_advance_ * tile_offset.contiguous();
|
| 346 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 347 |
+
}
|
| 348 |
+
}
|
| 349 |
+
is_residue_tile_ = false;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
CUTLASS_HOST_DEVICE
|
| 353 |
+
AccessType *get() const {
|
| 354 |
+
|
| 355 |
+
AccessType *ret_val = reinterpret_cast<AccessType *>(
|
| 356 |
+
pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element)));
|
| 357 |
+
|
| 358 |
+
return ret_val;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
/// Increment and return an instance to self.
|
| 362 |
+
CUTLASS_HOST_DEVICE
|
| 363 |
+
PredicatedTileAccessIterator2dThreadTile &operator++() {
|
| 364 |
+
|
| 365 |
+
iteration_thread_++;
|
| 366 |
+
|
| 367 |
+
if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided)
|
| 368 |
+
return *this;
|
| 369 |
+
|
| 370 |
+
iteration_thread_ = 0;
|
| 371 |
+
|
| 372 |
+
++iteration_contiguous_;
|
| 373 |
+
|
| 374 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 375 |
+
return *this;
|
| 376 |
+
|
| 377 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 378 |
+
// ThreadMap::Iteration::kContiguous)
|
| 379 |
+
iteration_contiguous_ = 0;
|
| 380 |
+
++iteration_strided_;
|
| 381 |
+
|
| 382 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 383 |
+
pointer_ += params_.inc_strided_;
|
| 384 |
+
return *this;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 388 |
+
// which means we enter the next tile.
|
| 389 |
+
iteration_strided_ = 0;
|
| 390 |
+
|
| 391 |
+
// advance to next tile
|
| 392 |
+
pointer_ += params_.inc_next_;
|
| 393 |
+
|
| 394 |
+
// now return to start tile - if the iterator is subsequently advanced, this
|
| 395 |
+
// subtraction as well as the subsequent integer addition are both elided by
|
| 396 |
+
// the compiler.
|
| 397 |
+
pointer_ -= params_.inc_advance_;
|
| 398 |
+
|
| 399 |
+
return *this;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
/// Increment and return an instance to self.
|
| 403 |
+
CUTLASS_HOST_DEVICE
|
| 404 |
+
PredicatedTileAccessIterator2dThreadTile operator++(int) {
|
| 405 |
+
PredicatedTileAccessIterator2dThreadTile self(*this);
|
| 406 |
+
operator++();
|
| 407 |
+
return self;
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
/// Clears the predicate set efficiently
|
| 411 |
+
CUTLASS_HOST_DEVICE
|
| 412 |
+
void clear_mask(bool enable = true) {
|
| 413 |
+
CUTLASS_PRAGMA_UNROLL
|
| 414 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 415 |
+
predicates_[i] = enable ? 0u : predicates_[i];
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
/// Clears the predicate set efficiently
|
| 421 |
+
CUTLASS_HOST_DEVICE
|
| 422 |
+
void enable_mask() {
|
| 423 |
+
CUTLASS_PRAGMA_UNROLL
|
| 424 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 425 |
+
predicates_[i] = 0xffffffff;
|
| 426 |
+
}
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 430 |
+
CUTLASS_HOST_DEVICE
|
| 431 |
+
void set_mask(Mask const &mask) {
|
| 432 |
+
CUTLASS_PRAGMA_UNROLL
|
| 433 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 434 |
+
predicates_[i] = mask[i];
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
/// Gets the mask
|
| 440 |
+
CUTLASS_HOST_DEVICE
|
| 441 |
+
void get_mask(Mask &mask) {
|
| 442 |
+
CUTLASS_PRAGMA_UNROLL
|
| 443 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 444 |
+
mask[i] = predicates_[i];
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
/// Returns whether access is valid or not
|
| 449 |
+
CUTLASS_HOST_DEVICE
|
| 450 |
+
bool valid() {
|
| 451 |
+
|
| 452 |
+
int pred_idx =
|
| 453 |
+
iteration_thread_ +
|
| 454 |
+
iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided +
|
| 455 |
+
iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
|
| 456 |
+
|
| 457 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 458 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 459 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 460 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 461 |
+
|
| 462 |
+
bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
|
| 463 |
+
|
| 464 |
+
return pred;
|
| 465 |
+
}
|
| 466 |
+
};
|
| 467 |
+
|
| 468 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 469 |
+
|
| 470 |
+
/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
|
| 471 |
+
///
|
| 472 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 473 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 474 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 475 |
+
/// MaskedTileIteratorConcept
|
| 476 |
+
///
|
| 477 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 478 |
+
typename ThreadMap_, typename AccessType_>
|
| 479 |
+
class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor,
|
| 480 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 481 |
+
public:
|
| 482 |
+
static_assert(
|
| 483 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 484 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 485 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 486 |
+
|
| 487 |
+
using Shape = Shape_;
|
| 488 |
+
using Element = Element_;
|
| 489 |
+
using Layout = layout::ColumnMajor;
|
| 490 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 491 |
+
using ThreadMap = ThreadMap_;
|
| 492 |
+
using AccessType = AccessType_;
|
| 493 |
+
|
| 494 |
+
using Index = typename Layout::Index;
|
| 495 |
+
using LongIndex = typename Layout::LongIndex;
|
| 496 |
+
|
| 497 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 498 |
+
using TensorView = TensorView<Element, Layout>;
|
| 499 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 500 |
+
|
| 501 |
+
using Pointer = Element *;
|
| 502 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 503 |
+
|
| 504 |
+
using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
|
| 505 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 506 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
|
| 507 |
+
|
| 508 |
+
/// Predicate vector stores mask to guard accesses
|
| 509 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 510 |
+
|
| 511 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 512 |
+
class Params {
|
| 513 |
+
private:
|
| 514 |
+
friend PredicatedTileAccessIterator2dThreadTile;
|
| 515 |
+
|
| 516 |
+
/// Parameters object
|
| 517 |
+
typename UnderlyingIterator::Params params_;
|
| 518 |
+
|
| 519 |
+
public:
|
| 520 |
+
|
| 521 |
+
/// Default ctor
|
| 522 |
+
CUTLASS_HOST_DEVICE
|
| 523 |
+
Params() { }
|
| 524 |
+
|
| 525 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 526 |
+
CUTLASS_HOST_DEVICE
|
| 527 |
+
Params(Layout const &layout)
|
| 528 |
+
: params_(layout::PitchLinear(layout.stride(0))){}
|
| 529 |
+
|
| 530 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 531 |
+
CUTLASS_HOST_DEVICE
|
| 532 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 533 |
+
: params_(base) {}
|
| 534 |
+
};
|
| 535 |
+
|
| 536 |
+
private:
|
| 537 |
+
//
|
| 538 |
+
// Data members
|
| 539 |
+
//
|
| 540 |
+
|
| 541 |
+
/// Underlying pitch-linear tile iterator
|
| 542 |
+
UnderlyingIterator iterator_;
|
| 543 |
+
|
| 544 |
+
public:
|
| 545 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 546 |
+
/// and thread ID
|
| 547 |
+
CUTLASS_HOST_DEVICE
|
| 548 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 549 |
+
///< Precomputed parameters object
|
| 550 |
+
Params const ¶ms,
|
| 551 |
+
///< Pointer to start of tensor
|
| 552 |
+
Pointer pointer,
|
| 553 |
+
///< Extent of tensor
|
| 554 |
+
TensorCoord extent,
|
| 555 |
+
///< ID of each participating thread
|
| 556 |
+
int thread_id,
|
| 557 |
+
///< Initial offset of threadblock
|
| 558 |
+
TensorCoord const &threadblock_offset)
|
| 559 |
+
: iterator_(params.params_, pointer,
|
| 560 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 561 |
+
thread_id,
|
| 562 |
+
layout::PitchLinearCoord(threadblock_offset.row(),
|
| 563 |
+
threadblock_offset.column())) {}
|
| 564 |
+
|
| 565 |
+
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
|
| 566 |
+
CUTLASS_HOST_DEVICE
|
| 567 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 568 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 569 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 570 |
+
TensorCoord extent, ///< Extent of tensor
|
| 571 |
+
int thread_id ///< ID of each participating thread
|
| 572 |
+
)
|
| 573 |
+
: PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
|
| 574 |
+
make_Coord(0, 0)) {}
|
| 575 |
+
|
| 576 |
+
/// Overrides the internal iteration index
|
| 577 |
+
CUTLASS_HOST_DEVICE
|
| 578 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 579 |
+
|
| 580 |
+
/// Adds a pointer offset in units of Element
|
| 581 |
+
CUTLASS_HOST_DEVICE
|
| 582 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 583 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 587 |
+
/// tiles
|
| 588 |
+
CUTLASS_HOST_DEVICE
|
| 589 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 590 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 591 |
+
}
|
| 592 |
+
|
| 593 |
+
/// Returns a pointer
|
| 594 |
+
CUTLASS_HOST_DEVICE
|
| 595 |
+
AccessType *get() const {
|
| 596 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
/// Advances to the next tile in memory.
|
| 600 |
+
///
|
| 601 |
+
/// The first time this method is called, predicates are updated, and the
|
| 602 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 603 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 604 |
+
/// pointer.
|
| 605 |
+
CUTLASS_HOST_DEVICE
|
| 606 |
+
PredicatedTileAccessIterator2dThreadTile &operator++() {
|
| 607 |
+
++iterator_;
|
| 608 |
+
return *this;
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
/// Advances to the next tile in memory.
|
| 612 |
+
///
|
| 613 |
+
/// The first time this method is called, predicates are updated, and the
|
| 614 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 615 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 616 |
+
/// pointer.
|
| 617 |
+
CUTLASS_HOST_DEVICE
|
| 618 |
+
PredicatedTileAccessIterator2dThreadTile operator++(int) {
|
| 619 |
+
PredicatedTileAccessIterator2dThreadTile self(*this);
|
| 620 |
+
operator++();
|
| 621 |
+
return self;
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
/// Clears the predicate set efficiently
|
| 625 |
+
CUTLASS_HOST_DEVICE
|
| 626 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 627 |
+
|
| 628 |
+
/// Clears the predicate set efficiently
|
| 629 |
+
CUTLASS_HOST_DEVICE
|
| 630 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 631 |
+
|
| 632 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 633 |
+
CUTLASS_HOST_DEVICE
|
| 634 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 635 |
+
|
| 636 |
+
/// Gets the mask
|
| 637 |
+
CUTLASS_HOST_DEVICE
|
| 638 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 639 |
+
|
| 640 |
+
/// Returns whether access is valid or not
|
| 641 |
+
CUTLASS_HOST_DEVICE
|
| 642 |
+
bool valid() {
|
| 643 |
+
return iterator_.valid();
|
| 644 |
+
}
|
| 645 |
+
};
|
| 646 |
+
|
| 647 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 648 |
+
|
| 649 |
+
/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data.
|
| 650 |
+
///
|
| 651 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 652 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 653 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 654 |
+
/// MaskedTileIteratorConcept
|
| 655 |
+
///
|
| 656 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 657 |
+
typename ThreadMap_, typename AccessType_>
|
| 658 |
+
class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajor,
|
| 659 |
+
AdvanceRank, ThreadMap_, AccessType_> {
|
| 660 |
+
public:
|
| 661 |
+
static_assert(
|
| 662 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 663 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 664 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 665 |
+
|
| 666 |
+
using Shape = Shape_;
|
| 667 |
+
using Element = Element_;
|
| 668 |
+
using Layout = layout::RowMajor;
|
| 669 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 670 |
+
using ThreadMap = ThreadMap_;
|
| 671 |
+
using AccessType = AccessType_;
|
| 672 |
+
|
| 673 |
+
using Index = typename Layout::Index;
|
| 674 |
+
using LongIndex = typename Layout::LongIndex;
|
| 675 |
+
|
| 676 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 677 |
+
using TensorView = TensorView<Element, Layout>;
|
| 678 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 679 |
+
|
| 680 |
+
using Pointer = Element *;
|
| 681 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 682 |
+
|
| 683 |
+
using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile<
|
| 684 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 685 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
|
| 686 |
+
|
| 687 |
+
/// Predicate vector stores mask to guard accesses
|
| 688 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 689 |
+
|
| 690 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 691 |
+
class Params {
|
| 692 |
+
private:
|
| 693 |
+
friend PredicatedTileAccessIterator2dThreadTile;
|
| 694 |
+
|
| 695 |
+
/// Parameters object
|
| 696 |
+
typename UnderlyingIterator::Params params_;
|
| 697 |
+
|
| 698 |
+
public:
|
| 699 |
+
|
| 700 |
+
/// Default ctor
|
| 701 |
+
CUTLASS_HOST_DEVICE
|
| 702 |
+
Params() { }
|
| 703 |
+
|
| 704 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 705 |
+
CUTLASS_HOST_DEVICE
|
| 706 |
+
Params(Layout const &layout)
|
| 707 |
+
: params_(layout::PitchLinear(layout.stride(0))){}
|
| 708 |
+
|
| 709 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 710 |
+
CUTLASS_HOST_DEVICE
|
| 711 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 712 |
+
: params_(base) {}
|
| 713 |
+
};
|
| 714 |
+
|
| 715 |
+
private:
|
| 716 |
+
//
|
| 717 |
+
// Data members
|
| 718 |
+
//
|
| 719 |
+
|
| 720 |
+
/// Underlying pitch-linear tile iterator
|
| 721 |
+
UnderlyingIterator iterator_;
|
| 722 |
+
|
| 723 |
+
public:
|
| 724 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 725 |
+
/// and thread ID
|
| 726 |
+
CUTLASS_HOST_DEVICE
|
| 727 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 728 |
+
///< Precomputed parameters object
|
| 729 |
+
Params const ¶ms,
|
| 730 |
+
///< Pointer to start of tensor
|
| 731 |
+
Pointer pointer,
|
| 732 |
+
///< Extent of tensor
|
| 733 |
+
TensorCoord extent,
|
| 734 |
+
///< ID of each participating thread
|
| 735 |
+
int thread_id,
|
| 736 |
+
///< Initial offset of threadblock
|
| 737 |
+
TensorCoord const &threadblock_offset)
|
| 738 |
+
: iterator_(params.params_, pointer,
|
| 739 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 740 |
+
thread_id,
|
| 741 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 742 |
+
threadblock_offset.row())) {}
|
| 743 |
+
|
| 744 |
+
/// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset
|
| 745 |
+
CUTLASS_HOST_DEVICE
|
| 746 |
+
PredicatedTileAccessIterator2dThreadTile(
|
| 747 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 748 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 749 |
+
TensorCoord extent, ///< Extent of tensor
|
| 750 |
+
int thread_id ///< ID of each participating thread
|
| 751 |
+
)
|
| 752 |
+
: PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id,
|
| 753 |
+
make_Coord(0, 0)) {}
|
| 754 |
+
|
| 755 |
+
/// Overrides the internal iteration index
|
| 756 |
+
CUTLASS_HOST_DEVICE
|
| 757 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 758 |
+
|
| 759 |
+
/// Adds a pointer offset in units of Element
|
| 760 |
+
CUTLASS_HOST_DEVICE
|
| 761 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 762 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 763 |
+
}
|
| 764 |
+
|
| 765 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 766 |
+
/// tiles
|
| 767 |
+
CUTLASS_HOST_DEVICE
|
| 768 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 769 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
/// Returns a pointer
|
| 773 |
+
CUTLASS_HOST_DEVICE
|
| 774 |
+
AccessType *get() const {
|
| 775 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 776 |
+
}
|
| 777 |
+
|
| 778 |
+
/// Advances to the next tile in memory.
|
| 779 |
+
///
|
| 780 |
+
/// The first time this method is called, predicates are updated, and the
|
| 781 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 782 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 783 |
+
/// pointer.
|
| 784 |
+
CUTLASS_HOST_DEVICE
|
| 785 |
+
PredicatedTileAccessIterator2dThreadTile &operator++() {
|
| 786 |
+
++iterator_;
|
| 787 |
+
return *this;
|
| 788 |
+
}
|
| 789 |
+
|
| 790 |
+
/// Advances to the next tile in memory.
|
| 791 |
+
///
|
| 792 |
+
/// The first time this method is called, predicates are updated, and the
|
| 793 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 794 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 795 |
+
/// pointer.
|
| 796 |
+
CUTLASS_HOST_DEVICE
|
| 797 |
+
PredicatedTileAccessIterator2dThreadTile operator++(int) {
|
| 798 |
+
PredicatedTileAccessIterator2dThreadTile self(*this);
|
| 799 |
+
operator++();
|
| 800 |
+
return self;
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
/// Clears the predicate set efficiently
|
| 804 |
+
CUTLASS_HOST_DEVICE
|
| 805 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 806 |
+
|
| 807 |
+
/// Clears the predicate set efficiently
|
| 808 |
+
CUTLASS_HOST_DEVICE
|
| 809 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 810 |
+
|
| 811 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 812 |
+
CUTLASS_HOST_DEVICE
|
| 813 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 814 |
+
|
| 815 |
+
/// Gets the mask
|
| 816 |
+
CUTLASS_HOST_DEVICE
|
| 817 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 818 |
+
|
| 819 |
+
/// Returns whether access is valid or not
|
| 820 |
+
CUTLASS_HOST_DEVICE
|
| 821 |
+
bool valid() {
|
| 822 |
+
return iterator_.valid();
|
| 823 |
+
}
|
| 824 |
+
};
|
| 825 |
+
|
| 826 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 827 |
+
|
| 828 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 829 |
+
|
| 830 |
+
} // namespace threadblock
|
| 831 |
+
} // namespace transform
|
| 832 |
+
} // namespace cutlass
|
| 833 |
+
|
| 834 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief
|
| 33 |
+
*/
|
| 34 |
+
|
| 35 |
+
#pragma once
|
| 36 |
+
|
| 37 |
+
#include "cutlass/cutlass.h"
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/detail/helper_macros.hpp"
|
| 40 |
+
#include "cutlass/layout/matrix.h"
|
| 41 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
namespace cutlass {
|
| 46 |
+
namespace transform {
|
| 47 |
+
namespace threadblock {
|
| 48 |
+
|
| 49 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
/// Predicated tile access iterator descriptor object containing template dependent state
|
| 52 |
+
struct PredicatedTileAccessIteratorDesc {
|
| 53 |
+
|
| 54 |
+
int element_size_bits = -1;
|
| 55 |
+
int advance_rank = -1;
|
| 56 |
+
layout::PitchLinearCoord threadblock_shape;
|
| 57 |
+
layout::PitchLinearCoord threadmap_iterations;
|
| 58 |
+
layout::PitchLinearCoord threadmap_delta;
|
| 59 |
+
|
| 60 |
+
//
|
| 61 |
+
// Methods
|
| 62 |
+
//
|
| 63 |
+
|
| 64 |
+
PredicatedTileAccessIteratorDesc() = default;
|
| 65 |
+
|
| 66 |
+
CUTLASS_HOST_DEVICE
|
| 67 |
+
PredicatedTileAccessIteratorDesc(
|
| 68 |
+
int element_size_bits_,
|
| 69 |
+
int advance_rank_,
|
| 70 |
+
layout::PitchLinearCoord threadblock_shape_,
|
| 71 |
+
layout::PitchLinearCoord threadmap_iterations_,
|
| 72 |
+
layout::PitchLinearCoord threadmap_delta_
|
| 73 |
+
):
|
| 74 |
+
element_size_bits(element_size_bits_),
|
| 75 |
+
advance_rank(advance_rank_),
|
| 76 |
+
threadblock_shape(threadblock_shape_),
|
| 77 |
+
threadmap_iterations(threadmap_iterations_),
|
| 78 |
+
threadmap_delta(threadmap_delta_)
|
| 79 |
+
{
|
| 80 |
+
#if 0
|
| 81 |
+
printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n",
|
| 82 |
+
element_size_bits,
|
| 83 |
+
advance_rank,
|
| 84 |
+
threadblock_shape.contiguous(), threadblock_shape.strided(),
|
| 85 |
+
threadmap_iterations.contiguous(), threadmap_iterations.strided(),
|
| 86 |
+
threadmap_delta.contiguous(), threadmap_delta.strided());
|
| 87 |
+
#endif
|
| 88 |
+
}
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 92 |
+
/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template
|
| 93 |
+
// dependent state
|
| 94 |
+
template <
|
| 95 |
+
typename Shape, typename Element, typename Layout,
|
| 96 |
+
int AdvanceRank, typename ThreadMap>
|
| 97 |
+
struct MakePredicatedTileAccessIteratorDesc;
|
| 98 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 99 |
+
|
| 100 |
+
/// Specialization of PredicatedTileAccessIterator for pitch-linear data.
|
| 101 |
+
template <
|
| 102 |
+
typename Shape, typename Element, int AdvanceRank,
|
| 103 |
+
typename ThreadMap>
|
| 104 |
+
struct MakePredicatedTileAccessIteratorDesc <
|
| 105 |
+
Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> {
|
| 106 |
+
|
| 107 |
+
CUTLASS_HOST_DEVICE
|
| 108 |
+
PredicatedTileAccessIteratorDesc operator()() {
|
| 109 |
+
|
| 110 |
+
return PredicatedTileAccessIteratorDesc(
|
| 111 |
+
sizeof_bits<Element>::value,
|
| 112 |
+
AdvanceRank,
|
| 113 |
+
{Shape::kContiguous, Shape::kStrided},
|
| 114 |
+
{ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided},
|
| 115 |
+
{ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided}
|
| 116 |
+
);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
};
|
| 120 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 121 |
+
|
| 122 |
+
/// Specialization of PredicatedTileAccessIterator for column-major data.
|
| 123 |
+
template <
|
| 124 |
+
typename Shape, typename Element, int AdvanceRank,
|
| 125 |
+
typename ThreadMap>
|
| 126 |
+
struct MakePredicatedTileAccessIteratorDesc <
|
| 127 |
+
Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> {
|
| 128 |
+
|
| 129 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 130 |
+
|
| 131 |
+
using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
|
| 132 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 133 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>;
|
| 134 |
+
|
| 135 |
+
CUTLASS_HOST_DEVICE
|
| 136 |
+
PredicatedTileAccessIteratorDesc operator()() {
|
| 137 |
+
|
| 138 |
+
return UnderlyingMakeOperator()();
|
| 139 |
+
}
|
| 140 |
+
};
|
| 141 |
+
|
| 142 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 143 |
+
|
| 144 |
+
/// Specialization of PredicatedTileAccessIterator for row-major data.
|
| 145 |
+
template <
|
| 146 |
+
typename Shape, typename Element, int AdvanceRank,
|
| 147 |
+
typename ThreadMap>
|
| 148 |
+
struct MakePredicatedTileAccessIteratorDesc <
|
| 149 |
+
Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> {
|
| 150 |
+
|
| 151 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 152 |
+
|
| 153 |
+
using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
|
| 154 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 155 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>;
|
| 156 |
+
|
| 157 |
+
CUTLASS_HOST_DEVICE
|
| 158 |
+
PredicatedTileAccessIteratorDesc operator()() {
|
| 159 |
+
|
| 160 |
+
return UnderlyingMakeOperator()();
|
| 161 |
+
}
|
| 162 |
+
};
|
| 163 |
+
|
| 164 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 165 |
+
|
| 166 |
+
/// Specialization of PredicatedTileAccessIterator for column-major interleaved data.
|
| 167 |
+
template <
|
| 168 |
+
typename Shape, typename Element, int AdvanceRank,
|
| 169 |
+
typename ThreadMap, int InterleavedK>
|
| 170 |
+
struct MakePredicatedTileAccessIteratorDesc <
|
| 171 |
+
Shape, Element, layout::ColumnMajorInterleaved<InterleavedK>, AdvanceRank, ThreadMap> {
|
| 172 |
+
|
| 173 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 174 |
+
static int const kInterleavedK = InterleavedK;
|
| 175 |
+
|
| 176 |
+
using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
|
| 177 |
+
layout::PitchLinearShape<Shape::kRow * kInterleavedK, Shape::kColumn / kInterleavedK>, Element,
|
| 178 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>;
|
| 179 |
+
|
| 180 |
+
CUTLASS_HOST_DEVICE
|
| 181 |
+
PredicatedTileAccessIteratorDesc operator()() {
|
| 182 |
+
|
| 183 |
+
return UnderlyingMakeOperator()();
|
| 184 |
+
}
|
| 185 |
+
};
|
| 186 |
+
|
| 187 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 188 |
+
|
| 189 |
+
/// Specialization of PredicatedTileAccessIterator for roww-major interleaved data.
|
| 190 |
+
template <
|
| 191 |
+
typename Shape, typename Element, int AdvanceRank,
|
| 192 |
+
typename ThreadMap, int InterleavedK>
|
| 193 |
+
struct MakePredicatedTileAccessIteratorDesc <
|
| 194 |
+
Shape, Element, layout::RowMajorInterleaved<InterleavedK>, AdvanceRank, ThreadMap> {
|
| 195 |
+
|
| 196 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 197 |
+
static int const kInterleavedK = InterleavedK;
|
| 198 |
+
|
| 199 |
+
using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc<
|
| 200 |
+
layout::PitchLinearShape<Shape::kColumn * kInterleavedK, Shape::kRow / kInterleavedK>, Element,
|
| 201 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>;
|
| 202 |
+
|
| 203 |
+
CUTLASS_HOST_DEVICE
|
| 204 |
+
PredicatedTileAccessIteratorDesc operator()() {
|
| 205 |
+
|
| 206 |
+
return UnderlyingMakeOperator()();
|
| 207 |
+
}
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 211 |
+
|
| 212 |
+
//
|
| 213 |
+
// Parameters struct
|
| 214 |
+
//
|
| 215 |
+
|
| 216 |
+
struct PredicatedTileAccessIteratorParams {
|
| 217 |
+
|
| 218 |
+
using Index = int32_t;
|
| 219 |
+
using LongIndex = int64_t;
|
| 220 |
+
|
| 221 |
+
//
|
| 222 |
+
// Data members
|
| 223 |
+
//
|
| 224 |
+
/// stride of pitch-linear layout (units of Element)
|
| 225 |
+
LongIndex stride_ = 0;
|
| 226 |
+
/// amount (in byte) to increment pointer to move to next access along
|
| 227 |
+
/// strided dimension
|
| 228 |
+
LongIndex inc_strided_ = 0;
|
| 229 |
+
/// amount (in byte) to increment pointer from last access to first access
|
| 230 |
+
/// of next tile
|
| 231 |
+
LongIndex inc_next_ = 0;
|
| 232 |
+
/// amount (in byte) to increment pointer from first access of current tile
|
| 233 |
+
/// to first access of next tile
|
| 234 |
+
LongIndex inc_advance_ = 0;
|
| 235 |
+
|
| 236 |
+
//
|
| 237 |
+
// Methods
|
| 238 |
+
//
|
| 239 |
+
|
| 240 |
+
CUTLASS_HOST_DEVICE
|
| 241 |
+
Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) {
|
| 242 |
+
CUTLASS_ASSERT(desc.element_size_bits > 0);
|
| 243 |
+
CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1);
|
| 244 |
+
|
| 245 |
+
stride_ = stride;
|
| 246 |
+
|
| 247 |
+
inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) *
|
| 248 |
+
desc.element_size_bits / 8;
|
| 249 |
+
|
| 250 |
+
if (desc.advance_rank) {
|
| 251 |
+
// advance along strided dimension
|
| 252 |
+
inc_advance_ =
|
| 253 |
+
desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8;
|
| 254 |
+
} else {
|
| 255 |
+
// advance along contiguous dimension
|
| 256 |
+
inc_advance_ = desc.threadblock_shape.contiguous() * desc.element_size_bits / 8;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) *
|
| 260 |
+
desc.threadmap_delta.strided() * LongIndex(stride_) *
|
| 261 |
+
desc.element_size_bits / 8;
|
| 262 |
+
|
| 263 |
+
return Status::kSuccess;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
CUTLASS_HOST_DEVICE
|
| 267 |
+
Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) {
|
| 268 |
+
return initialize(LongIndex(stride), desc);
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
PredicatedTileAccessIteratorParams() = default;
|
| 272 |
+
|
| 273 |
+
CUTLASS_HOST_DEVICE
|
| 274 |
+
PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) {
|
| 275 |
+
initialize(stride, desc);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
CUTLASS_HOST_DEVICE
|
| 279 |
+
PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) {
|
| 280 |
+
initialize(stride, desc);
|
| 281 |
+
}
|
| 282 |
+
};
|
| 283 |
+
|
| 284 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 285 |
+
|
| 286 |
+
} // namespace threadblock
|
| 287 |
+
} // namespace transform
|
| 288 |
+
} // namespace cutlass
|
| 289 |
+
|
| 290 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h
ADDED
|
@@ -0,0 +1,892 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates calculating the address and predicates to the load of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
|
| 35 |
+
This iterator uses masks to guard out-of-bounds accesses and visits the last
|
| 36 |
+
"residue" tile first, with the objective of minimizing predicate mask updates
|
| 37 |
+
during steady-state operation.
|
| 38 |
+
|
| 39 |
+
A precomputed "Params" object minimizes the amount of state that must be
|
| 40 |
+
stored in registers, and integer addition is used to advance the pointer
|
| 41 |
+
through memory.
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
*/
|
| 45 |
+
|
| 46 |
+
#pragma once
|
| 47 |
+
|
| 48 |
+
#include "cutlass/blas3.h"
|
| 49 |
+
#include "cutlass/layout/matrix.h"
|
| 50 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 51 |
+
#include "cutlass/matrix_shape.h"
|
| 52 |
+
#include "cutlass/predicate_vector.h"
|
| 53 |
+
#include "cutlass/tensor_ref.h"
|
| 54 |
+
#include "cutlass/tensor_view.h"
|
| 55 |
+
|
| 56 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
namespace cutlass {
|
| 61 |
+
namespace transform {
|
| 62 |
+
namespace threadblock {
|
| 63 |
+
|
| 64 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
/// PredicatedTileAccessIteratorTriangularMatrix
|
| 67 |
+
///
|
| 68 |
+
template <typename Shape, typename Element, typename Layout,
|
| 69 |
+
int AdvanceRank, typename ThreadMap,
|
| 70 |
+
SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
|
| 71 |
+
typename AccessType>
|
| 72 |
+
class PredicatedTileAccessIteratorTriangularMatrix;
|
| 73 |
+
|
| 74 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 75 |
+
|
| 76 |
+
/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data.
|
| 77 |
+
///
|
| 78 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 79 |
+
typename ThreadMap_, SideMode kSideMode, FillMode kFillMode, DiagType kDiagType, typename AccessType_>
|
| 80 |
+
class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::PitchLinear,
|
| 81 |
+
AdvanceRank, ThreadMap_, kSideMode, kFillMode, kDiagType, AccessType_> {
|
| 82 |
+
public:
|
| 83 |
+
static_assert(
|
| 84 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 85 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 86 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 87 |
+
|
| 88 |
+
using Shape = Shape_;
|
| 89 |
+
using Element = Element_;
|
| 90 |
+
using Layout = layout::PitchLinear;
|
| 91 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 92 |
+
using ThreadMap = ThreadMap_;
|
| 93 |
+
using AccessType = AccessType_;
|
| 94 |
+
|
| 95 |
+
using Index = typename Layout::Index;
|
| 96 |
+
using LongIndex = typename Layout::LongIndex;
|
| 97 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 98 |
+
|
| 99 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 100 |
+
using TensorView = TensorView<Element, Layout>;
|
| 101 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 102 |
+
|
| 103 |
+
using Pointer = Element *;
|
| 104 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 105 |
+
|
| 106 |
+
static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
|
| 107 |
+
|
| 108 |
+
using CompareOp = typename TrMatrixCompareOp<kFillMode, kDiagType>::Type;
|
| 109 |
+
|
| 110 |
+
static_assert( kFillMode == FillMode::kFull ||
|
| 111 |
+
((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1),
|
| 112 |
+
"BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1");
|
| 113 |
+
|
| 114 |
+
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
|
| 115 |
+
"Vectors implied by the thread map must be divisible by the access type.");
|
| 116 |
+
|
| 117 |
+
static int const kPredicatesPerByte = 4;
|
| 118 |
+
static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
|
| 119 |
+
|
| 120 |
+
static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
|
| 121 |
+
|
| 122 |
+
/// Number of 32b words containing predicates
|
| 123 |
+
static int const kPredicateByteCount =
|
| 124 |
+
(kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
| 125 |
+
static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
|
| 126 |
+
|
| 127 |
+
static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
|
| 128 |
+
|
| 129 |
+
static_assert(kPredicateWordCount <= 4, "Too many predicates.");
|
| 130 |
+
|
| 131 |
+
/// Predicate vector stores mask to guard accesses
|
| 132 |
+
using Mask = Array<uint32_t, kPredicateWordCount>;
|
| 133 |
+
|
| 134 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 135 |
+
class Params {
|
| 136 |
+
public:
|
| 137 |
+
friend PredicatedTileAccessIteratorTriangularMatrix;
|
| 138 |
+
|
| 139 |
+
private:
|
| 140 |
+
/// stride of pitch-linear layout (units of Element)
|
| 141 |
+
StrideIndex stride_;
|
| 142 |
+
/// (true) pitch-linear layout is mapped to row-major matrix
|
| 143 |
+
/// (false) pitch-linear layout is mapped to column-major matrix
|
| 144 |
+
bool is_row_major_;
|
| 145 |
+
/// for vectorized access across the diagonal boundary guard condition is
|
| 146 |
+
/// checked for the element on the boundary
|
| 147 |
+
int access_diagonal_boundary_;
|
| 148 |
+
/// amount (in byte) to increment pointer to move to next access along
|
| 149 |
+
/// strided dimension
|
| 150 |
+
LongIndex inc_strided_;
|
| 151 |
+
/// amount (in byte) to increment pointer from last access to first access
|
| 152 |
+
/// of next tile
|
| 153 |
+
LongIndex inc_next_;
|
| 154 |
+
/// amount (in byte) to increment pointer from first access of current tile
|
| 155 |
+
/// to first access of next tile
|
| 156 |
+
LongIndex inc_advance_;
|
| 157 |
+
|
| 158 |
+
public:
|
| 159 |
+
|
| 160 |
+
// Default ctor
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { }
|
| 163 |
+
|
| 164 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 165 |
+
CUTLASS_HOST_DEVICE
|
| 166 |
+
Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) :
|
| 167 |
+
stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) {
|
| 168 |
+
|
| 169 |
+
inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) *
|
| 170 |
+
sizeof_bits<Element>::value / 8;
|
| 171 |
+
|
| 172 |
+
if (kAdvanceRank) {
|
| 173 |
+
// advance along strided dimension
|
| 174 |
+
inc_advance_ =
|
| 175 |
+
Shape::kStrided * LongIndex(stride_) * sizeof_bits<Element>::value / 8;
|
| 176 |
+
} else {
|
| 177 |
+
// advance along contiguous dimension
|
| 178 |
+
inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) *
|
| 182 |
+
ThreadMap::Delta::kStrided * LongIndex(stride_) *
|
| 183 |
+
sizeof_bits<Element>::value / 8;
|
| 184 |
+
|
| 185 |
+
};
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
};
|
| 189 |
+
|
| 190 |
+
private:
|
| 191 |
+
/// Internal pointer type permits fast address arithmetic
|
| 192 |
+
using BytePointer = char *;
|
| 193 |
+
|
| 194 |
+
private:
|
| 195 |
+
//
|
| 196 |
+
// Data members
|
| 197 |
+
//
|
| 198 |
+
|
| 199 |
+
/// Parameters object with precomputed internal state
|
| 200 |
+
Params const ¶ms_;
|
| 201 |
+
|
| 202 |
+
/// Internal pointer to first access of tile
|
| 203 |
+
BytePointer pointer_;
|
| 204 |
+
|
| 205 |
+
/// Guard predicates
|
| 206 |
+
uint32_t predicates_[kPredicateWordCount];
|
| 207 |
+
|
| 208 |
+
/// Track global memory addresses on the diagonal
|
| 209 |
+
/// To ignore imag part for diagonal elements of hermitian matrices
|
| 210 |
+
uint32_t predicates_onDiag_[kPredicateWordCount];
|
| 211 |
+
|
| 212 |
+
/// Size of tensor
|
| 213 |
+
TensorCoord extent_;
|
| 214 |
+
|
| 215 |
+
/// Initial offset for each thread
|
| 216 |
+
TensorCoord thread_offset_;
|
| 217 |
+
|
| 218 |
+
/// Iteration along vectors implied by the thread map
|
| 219 |
+
int iteration_vector_;
|
| 220 |
+
|
| 221 |
+
/// Iteration in the contiguous dimension
|
| 222 |
+
int iteration_contiguous_;
|
| 223 |
+
|
| 224 |
+
/// Iteration in the strided dimension
|
| 225 |
+
int iteration_strided_;
|
| 226 |
+
|
| 227 |
+
private:
|
| 228 |
+
/// Computes predicates based on internally tracked per-thread offset.
|
| 229 |
+
CUTLASS_DEVICE
|
| 230 |
+
void compute_predicates_(
|
| 231 |
+
/// Extent of the matrix window
|
| 232 |
+
TensorCoord extent) {
|
| 233 |
+
|
| 234 |
+
CUTLASS_PRAGMA_UNROLL
|
| 235 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 236 |
+
predicates_[i] = 0u;
|
| 237 |
+
predicates_onDiag_[i] = 0u;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
CompareOp compare_op;
|
| 241 |
+
|
| 242 |
+
CUTLASS_PRAGMA_UNROLL
|
| 243 |
+
for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
|
| 244 |
+
|
| 245 |
+
int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 246 |
+
|
| 247 |
+
int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
|
| 248 |
+
|
| 249 |
+
int c = access_residual / kAccessesPerVector;
|
| 250 |
+
int v = access_residual % kAccessesPerVector;
|
| 251 |
+
|
| 252 |
+
TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
|
| 253 |
+
s * ThreadMap::Delta::kStrided);
|
| 254 |
+
|
| 255 |
+
TensorCoord coord = thread_offset_ + iteration_coord;
|
| 256 |
+
|
| 257 |
+
bool guard;
|
| 258 |
+
bool onDiag = false;
|
| 259 |
+
|
| 260 |
+
guard = ((coord.strided() < extent.strided()) &&
|
| 261 |
+
(coord.contiguous() < extent.contiguous()));
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
// guard access on the wrong side of the triagular matrix diagonal
|
| 265 |
+
if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) {
|
| 266 |
+
coord += TensorCoord{params_.access_diagonal_boundary_, 0};
|
| 267 |
+
|
| 268 |
+
bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_;
|
| 269 |
+
bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_;
|
| 270 |
+
|
| 271 |
+
guard = guard && triagular_guard_row_major && triagular_guard_col_major;
|
| 272 |
+
|
| 273 |
+
if (kDiagType == DiagType::kUnit) {
|
| 274 |
+
onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false;
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
|
| 279 |
+
int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord;
|
| 280 |
+
int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord;
|
| 281 |
+
int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte;
|
| 282 |
+
int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte;
|
| 283 |
+
|
| 284 |
+
predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag));
|
| 285 |
+
|
| 286 |
+
int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
|
| 287 |
+
|
| 288 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 289 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 290 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 291 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 292 |
+
|
| 293 |
+
predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
|
| 294 |
+
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
public:
|
| 300 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 301 |
+
/// and thread ID
|
| 302 |
+
CUTLASS_HOST_DEVICE
|
| 303 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 304 |
+
/// Precomputed parameters object
|
| 305 |
+
Params const ¶ms,
|
| 306 |
+
/// Pointer to start of tensor
|
| 307 |
+
Pointer pointer,
|
| 308 |
+
/// Extent of tensor
|
| 309 |
+
TensorCoord extent,
|
| 310 |
+
/// ID of each participating thread
|
| 311 |
+
int thread_id,
|
| 312 |
+
/// Initial offset of threadblock
|
| 313 |
+
TensorCoord const &threadblock_offset)
|
| 314 |
+
: params_(params),
|
| 315 |
+
pointer_(reinterpret_cast<BytePointer>(const_cast<NonConstPointer>(pointer))),
|
| 316 |
+
extent_(extent) {
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
// Per-thread offset in logical coordinates of tensor
|
| 320 |
+
thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
|
| 321 |
+
|
| 322 |
+
// update internal pointers
|
| 323 |
+
Layout layout(params_.stride_);
|
| 324 |
+
add_pointer_offset(layout(thread_offset_));
|
| 325 |
+
|
| 326 |
+
compute_predicates_(extent_);
|
| 327 |
+
|
| 328 |
+
set_iteration_index(0);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
|
| 332 |
+
CUTLASS_HOST_DEVICE
|
| 333 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 334 |
+
/// Precomputed parameters object
|
| 335 |
+
Params const ¶ms,
|
| 336 |
+
/// Pointer to start of tensor
|
| 337 |
+
Pointer pointer,
|
| 338 |
+
/// Extent of tensor
|
| 339 |
+
TensorCoord extent,
|
| 340 |
+
///< ID of each participating thread
|
| 341 |
+
int thread_id)
|
| 342 |
+
: PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
|
| 343 |
+
make_Coord(0, 0)) {}
|
| 344 |
+
|
| 345 |
+
/// Overrides the internal iteration index
|
| 346 |
+
CUTLASS_HOST_DEVICE
|
| 347 |
+
void set_iteration_index(int index) {
|
| 348 |
+
|
| 349 |
+
iteration_vector_ = index % kAccessesPerVector;
|
| 350 |
+
int residual_access = index / kAccessesPerVector;
|
| 351 |
+
|
| 352 |
+
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
|
| 353 |
+
iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
|
| 354 |
+
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
/// Adds a pointer offset in units of Element
|
| 358 |
+
CUTLASS_HOST_DEVICE
|
| 359 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 360 |
+
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
| 364 |
+
CUTLASS_DEVICE
|
| 365 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 366 |
+
|
| 367 |
+
if (kAdvanceRank) {
|
| 368 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
|
| 369 |
+
pointer_ += Shape::kContiguous * tile_offset.contiguous();
|
| 370 |
+
thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()};
|
| 371 |
+
} else {
|
| 372 |
+
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
|
| 373 |
+
pointer_ += Shape::kStrided * tile_offset.strided();
|
| 374 |
+
thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0};
|
| 375 |
+
}
|
| 376 |
+
|
| 377 |
+
compute_predicates_(extent_);
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
/// Returns a pointer
|
| 381 |
+
CUTLASS_HOST_DEVICE
|
| 382 |
+
AccessType *get() const {
|
| 383 |
+
return reinterpret_cast<AccessType *>(
|
| 384 |
+
pointer_ +
|
| 385 |
+
iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
/// Increment and return an instance to self.
|
| 389 |
+
CUTLASS_HOST_DEVICE
|
| 390 |
+
PredicatedTileAccessIteratorTriangularMatrix &operator++() {
|
| 391 |
+
|
| 392 |
+
++iteration_vector_;
|
| 393 |
+
if (iteration_vector_ < kAccessesPerVector) {
|
| 394 |
+
return *this;
|
| 395 |
+
}
|
| 396 |
+
|
| 397 |
+
iteration_vector_ = 0;
|
| 398 |
+
++iteration_contiguous_;
|
| 399 |
+
|
| 400 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
|
| 401 |
+
return *this;
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 405 |
+
// ThreadMap::Iteration::kContiguous)
|
| 406 |
+
iteration_contiguous_ = 0;
|
| 407 |
+
++iteration_strided_;
|
| 408 |
+
|
| 409 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 410 |
+
pointer_ += params_.inc_strided_;
|
| 411 |
+
return *this;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 415 |
+
// which means we enter the next tile.
|
| 416 |
+
iteration_strided_ = 0;
|
| 417 |
+
|
| 418 |
+
// advance to next tile
|
| 419 |
+
pointer_ += params_.inc_next_;
|
| 420 |
+
|
| 421 |
+
// now return to start tile - if the iterator is subsequently advanced, this
|
| 422 |
+
// subtraction as well as the subsequent integer addition are both elided by
|
| 423 |
+
// the compiler.
|
| 424 |
+
pointer_ -= params_.inc_advance_;
|
| 425 |
+
|
| 426 |
+
return *this;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
/// Increment and return an instance to self.
|
| 430 |
+
CUTLASS_HOST_DEVICE
|
| 431 |
+
PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
|
| 432 |
+
PredicatedTileAccessIteratorTriangularMatrix self(*this);
|
| 433 |
+
operator++();
|
| 434 |
+
return self;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
/// Clears the predicate set efficiently
|
| 438 |
+
CUTLASS_HOST_DEVICE
|
| 439 |
+
void clear_mask(bool enable = true) {
|
| 440 |
+
CUTLASS_PRAGMA_UNROLL
|
| 441 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 442 |
+
predicates_[i] = enable ? 0u : predicates_[i];
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
/// Clears the predicate set efficiently
|
| 448 |
+
CUTLASS_HOST_DEVICE
|
| 449 |
+
void enable_mask() {
|
| 450 |
+
CUTLASS_PRAGMA_UNROLL
|
| 451 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 452 |
+
predicates_[i] = 0xffffffff;
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 457 |
+
CUTLASS_HOST_DEVICE
|
| 458 |
+
void set_mask(Mask const &mask) {
|
| 459 |
+
CUTLASS_PRAGMA_UNROLL
|
| 460 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 461 |
+
predicates_[i] = mask[i];
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
/// Gets the mask
|
| 467 |
+
CUTLASS_HOST_DEVICE
|
| 468 |
+
void get_mask(Mask &mask) {
|
| 469 |
+
CUTLASS_PRAGMA_UNROLL
|
| 470 |
+
for (int i = 0; i < kPredicateWordCount; ++i) {
|
| 471 |
+
mask[i] = predicates_[i];
|
| 472 |
+
}
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
/// Return if the address in on the diagonal
|
| 476 |
+
CUTLASS_HOST_DEVICE
|
| 477 |
+
bool getOnDiag() {
|
| 478 |
+
int pred_idx =
|
| 479 |
+
iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
|
| 480 |
+
|
| 481 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 482 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 483 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 484 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 485 |
+
|
| 486 |
+
bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
|
| 487 |
+
return pred;
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
/// Returns whether access is valid or not
|
| 491 |
+
CUTLASS_HOST_DEVICE
|
| 492 |
+
bool valid() {
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
int pred_idx =
|
| 496 |
+
iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
|
| 497 |
+
|
| 498 |
+
int word_idx = pred_idx / kPredicatesPerWord;
|
| 499 |
+
int residual = pred_idx % kPredicatesPerWord;
|
| 500 |
+
int byte_idx = residual / kPredicatesPerByte;
|
| 501 |
+
int bit_idx = residual % kPredicatesPerByte;
|
| 502 |
+
|
| 503 |
+
bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
|
| 504 |
+
return pred;
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
//return true;
|
| 508 |
+
}
|
| 509 |
+
};
|
| 510 |
+
|
| 511 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 512 |
+
|
| 513 |
+
/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data.
|
| 514 |
+
///
|
| 515 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 516 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 517 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 518 |
+
/// MaskedTileIteratorConcept
|
| 519 |
+
///
|
| 520 |
+
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
|
| 521 |
+
SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
|
| 522 |
+
typename AccessType_>
|
| 523 |
+
class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::ColumnMajor,
|
| 524 |
+
AdvanceRank, ThreadMap_, kSideMode, kFillMode, kDiagType,
|
| 525 |
+
AccessType_> {
|
| 526 |
+
public:
|
| 527 |
+
static_assert(
|
| 528 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 529 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 530 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 531 |
+
|
| 532 |
+
using Shape = Shape_;
|
| 533 |
+
using Element = Element_;
|
| 534 |
+
using Layout = layout::ColumnMajor;
|
| 535 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 536 |
+
using ThreadMap = ThreadMap_;
|
| 537 |
+
using AccessType = AccessType_;
|
| 538 |
+
|
| 539 |
+
using Index = typename Layout::Index;
|
| 540 |
+
using LongIndex = typename Layout::LongIndex;
|
| 541 |
+
|
| 542 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 543 |
+
using TensorView = TensorView<Element, Layout>;
|
| 544 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 545 |
+
|
| 546 |
+
using Pointer = Element *;
|
| 547 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 548 |
+
|
| 549 |
+
using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix<
|
| 550 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 551 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
|
| 552 |
+
kSideMode, kFillMode, kDiagType, AccessType>;
|
| 553 |
+
|
| 554 |
+
/// Predicate vector stores mask to guard accesses
|
| 555 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 556 |
+
|
| 557 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 558 |
+
|
| 559 |
+
static int const kAccessDiagonalBoundary =
|
| 560 |
+
(kFillMode == FillMode::kLower) ? (AccessType::kElements - 1) : 0;
|
| 561 |
+
|
| 562 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 563 |
+
class Params {
|
| 564 |
+
private:
|
| 565 |
+
friend PredicatedTileAccessIteratorTriangularMatrix;
|
| 566 |
+
|
| 567 |
+
/// Parameters object
|
| 568 |
+
typename UnderlyingIterator::Params params_;
|
| 569 |
+
|
| 570 |
+
public:
|
| 571 |
+
|
| 572 |
+
/// Default ctor
|
| 573 |
+
CUTLASS_HOST_DEVICE
|
| 574 |
+
Params() { }
|
| 575 |
+
|
| 576 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 577 |
+
CUTLASS_HOST_DEVICE
|
| 578 |
+
Params(Layout const &layout)
|
| 579 |
+
: params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){};
|
| 580 |
+
};
|
| 581 |
+
|
| 582 |
+
private:
|
| 583 |
+
//
|
| 584 |
+
// Data members
|
| 585 |
+
//
|
| 586 |
+
|
| 587 |
+
/// Underlying pitch-linear tile iterator
|
| 588 |
+
UnderlyingIterator iterator_;
|
| 589 |
+
|
| 590 |
+
public:
|
| 591 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 592 |
+
/// and thread ID
|
| 593 |
+
CUTLASS_HOST_DEVICE
|
| 594 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 595 |
+
///< Precomputed parameters object
|
| 596 |
+
Params const ¶ms,
|
| 597 |
+
///< Pointer to start of tensor
|
| 598 |
+
Pointer pointer,
|
| 599 |
+
///< Extent of tensor
|
| 600 |
+
TensorCoord extent,
|
| 601 |
+
///< ID of each participating thread
|
| 602 |
+
int thread_id,
|
| 603 |
+
///< Initial offset of threadblock
|
| 604 |
+
TensorCoord const &threadblock_offset)
|
| 605 |
+
: iterator_(params.params_, pointer,
|
| 606 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 607 |
+
thread_id,
|
| 608 |
+
layout::PitchLinearCoord(threadblock_offset.row(),
|
| 609 |
+
threadblock_offset.column())) {}
|
| 610 |
+
|
| 611 |
+
/// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
|
| 612 |
+
CUTLASS_HOST_DEVICE
|
| 613 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 614 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 615 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 616 |
+
TensorCoord extent, ///< Extent of tensor
|
| 617 |
+
int thread_id ///< ID of each participating thread
|
| 618 |
+
)
|
| 619 |
+
: PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
|
| 620 |
+
make_Coord(0, 0)) {}
|
| 621 |
+
|
| 622 |
+
/// Overrides the internal iteration index
|
| 623 |
+
CUTLASS_HOST_DEVICE
|
| 624 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 625 |
+
|
| 626 |
+
/// Adds a pointer offset in units of Element
|
| 627 |
+
CUTLASS_HOST_DEVICE
|
| 628 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 629 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 630 |
+
}
|
| 631 |
+
|
| 632 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 633 |
+
/// tiles
|
| 634 |
+
CUTLASS_HOST_DEVICE
|
| 635 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 636 |
+
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
/// Returns a pointer
|
| 640 |
+
CUTLASS_HOST_DEVICE
|
| 641 |
+
AccessType *get() const {
|
| 642 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
/// Advances to the next tile in memory.
|
| 646 |
+
///
|
| 647 |
+
/// The first time this method is called, predicates are updated, and the
|
| 648 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 649 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 650 |
+
/// pointer.
|
| 651 |
+
CUTLASS_HOST_DEVICE
|
| 652 |
+
PredicatedTileAccessIteratorTriangularMatrix &operator++() {
|
| 653 |
+
++iterator_;
|
| 654 |
+
return *this;
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
/// Advances to the next tile in memory.
|
| 658 |
+
///
|
| 659 |
+
/// The first time this method is called, predicates are updated, and the
|
| 660 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 661 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 662 |
+
/// pointer.
|
| 663 |
+
CUTLASS_HOST_DEVICE
|
| 664 |
+
PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
|
| 665 |
+
PredicatedTileAccessIteratorTriangularMatrix self(*this);
|
| 666 |
+
operator++();
|
| 667 |
+
return self;
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
/// Clears the predicate set efficiently
|
| 671 |
+
CUTLASS_HOST_DEVICE
|
| 672 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 673 |
+
|
| 674 |
+
/// Clears the predicate set efficiently
|
| 675 |
+
CUTLASS_HOST_DEVICE
|
| 676 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 677 |
+
|
| 678 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 679 |
+
CUTLASS_HOST_DEVICE
|
| 680 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 681 |
+
|
| 682 |
+
/// Gets the mask
|
| 683 |
+
CUTLASS_HOST_DEVICE
|
| 684 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 685 |
+
|
| 686 |
+
/// Return if the address in on the diagonal
|
| 687 |
+
CUTLASS_HOST_DEVICE
|
| 688 |
+
bool getOnDiag() {
|
| 689 |
+
return iterator_.getOnDiag();
|
| 690 |
+
}
|
| 691 |
+
|
| 692 |
+
/// Returns whether access is valid or not
|
| 693 |
+
CUTLASS_HOST_DEVICE
|
| 694 |
+
bool valid() {
|
| 695 |
+
return iterator_.valid();
|
| 696 |
+
}
|
| 697 |
+
};
|
| 698 |
+
|
| 699 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 700 |
+
|
| 701 |
+
/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data.
|
| 702 |
+
///
|
| 703 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 704 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 705 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 706 |
+
/// MaskedTileIteratorConcept
|
| 707 |
+
///
|
| 708 |
+
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
|
| 709 |
+
SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
|
| 710 |
+
typename AccessType_>
|
| 711 |
+
class PredicatedTileAccessIteratorTriangularMatrix<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_,
|
| 712 |
+
kSideMode, kFillMode, kDiagType, AccessType_> {
|
| 713 |
+
public:
|
| 714 |
+
static_assert(
|
| 715 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 716 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 717 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 718 |
+
|
| 719 |
+
using Shape = Shape_;
|
| 720 |
+
using Element = Element_;
|
| 721 |
+
using Layout = layout::RowMajor;
|
| 722 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 723 |
+
using ThreadMap = ThreadMap_;
|
| 724 |
+
using AccessType = AccessType_;
|
| 725 |
+
|
| 726 |
+
using Index = typename Layout::Index;
|
| 727 |
+
using LongIndex = typename Layout::LongIndex;
|
| 728 |
+
|
| 729 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 730 |
+
using TensorView = TensorView<Element, Layout>;
|
| 731 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 732 |
+
|
| 733 |
+
using Pointer = Element *;
|
| 734 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 735 |
+
|
| 736 |
+
using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix<
|
| 737 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 738 |
+
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
|
| 739 |
+
kSideMode, kFillMode, kDiagType, AccessType>;
|
| 740 |
+
|
| 741 |
+
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
|
| 742 |
+
|
| 743 |
+
static int const kAccessDiagonalBoundary =
|
| 744 |
+
(kFillMode == FillMode::kUpper) ? (AccessType::kElements - 1) : 0;
|
| 745 |
+
|
| 746 |
+
/// Predicate vector stores mask to guard accesses
|
| 747 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 748 |
+
|
| 749 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 750 |
+
class Params {
|
| 751 |
+
private:
|
| 752 |
+
friend PredicatedTileAccessIteratorTriangularMatrix;
|
| 753 |
+
|
| 754 |
+
/// Parameters object
|
| 755 |
+
typename UnderlyingIterator::Params params_;
|
| 756 |
+
|
| 757 |
+
public:
|
| 758 |
+
|
| 759 |
+
/// Default ctor
|
| 760 |
+
CUTLASS_HOST_DEVICE
|
| 761 |
+
Params() { }
|
| 762 |
+
|
| 763 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 764 |
+
CUTLASS_HOST_DEVICE
|
| 765 |
+
Params(Layout const &layout)
|
| 766 |
+
: params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){};
|
| 767 |
+
};
|
| 768 |
+
|
| 769 |
+
private:
|
| 770 |
+
//
|
| 771 |
+
// Data members
|
| 772 |
+
//
|
| 773 |
+
|
| 774 |
+
/// Underlying pitch-linear tile iterator
|
| 775 |
+
UnderlyingIterator iterator_;
|
| 776 |
+
|
| 777 |
+
public:
|
| 778 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 779 |
+
/// and thread ID
|
| 780 |
+
CUTLASS_HOST_DEVICE
|
| 781 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 782 |
+
///< Precomputed parameters object
|
| 783 |
+
Params const ¶ms,
|
| 784 |
+
///< Pointer to start of tensor
|
| 785 |
+
Pointer pointer,
|
| 786 |
+
///< Extent of tensor
|
| 787 |
+
TensorCoord extent,
|
| 788 |
+
///< ID of each participating thread
|
| 789 |
+
int thread_id,
|
| 790 |
+
///< Initial offset of threadblock
|
| 791 |
+
TensorCoord const &threadblock_offset)
|
| 792 |
+
: iterator_(params.params_, pointer,
|
| 793 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 794 |
+
thread_id,
|
| 795 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 796 |
+
threadblock_offset.row())) {}
|
| 797 |
+
|
| 798 |
+
/// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset
|
| 799 |
+
CUTLASS_HOST_DEVICE
|
| 800 |
+
PredicatedTileAccessIteratorTriangularMatrix(
|
| 801 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 802 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 803 |
+
TensorCoord extent, ///< Extent of tensor
|
| 804 |
+
int thread_id ///< ID of each participating thread
|
| 805 |
+
)
|
| 806 |
+
: PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id,
|
| 807 |
+
make_Coord(0, 0)) {}
|
| 808 |
+
|
| 809 |
+
/// Overrides the internal iteration index
|
| 810 |
+
CUTLASS_HOST_DEVICE
|
| 811 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 812 |
+
|
| 813 |
+
/// Adds a pointer offset in units of Element
|
| 814 |
+
CUTLASS_HOST_DEVICE
|
| 815 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 816 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 820 |
+
/// tiles
|
| 821 |
+
CUTLASS_HOST_DEVICE
|
| 822 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 823 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
/// Returns a pointer
|
| 827 |
+
CUTLASS_HOST_DEVICE
|
| 828 |
+
AccessType *get() const {
|
| 829 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
/// Advances to the next tile in memory.
|
| 833 |
+
///
|
| 834 |
+
/// The first time this method is called, predicates are updated, and the
|
| 835 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 836 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 837 |
+
/// pointer.
|
| 838 |
+
CUTLASS_HOST_DEVICE
|
| 839 |
+
PredicatedTileAccessIteratorTriangularMatrix &operator++() {
|
| 840 |
+
++iterator_;
|
| 841 |
+
return *this;
|
| 842 |
+
}
|
| 843 |
+
|
| 844 |
+
/// Advances to the next tile in memory.
|
| 845 |
+
///
|
| 846 |
+
/// The first time this method is called, predicates are updated, and the
|
| 847 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 848 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 849 |
+
/// pointer.
|
| 850 |
+
CUTLASS_HOST_DEVICE
|
| 851 |
+
PredicatedTileAccessIteratorTriangularMatrix operator++(int) {
|
| 852 |
+
PredicatedTileAccessIteratorTriangularMatrix self(*this);
|
| 853 |
+
operator++();
|
| 854 |
+
return self;
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
/// Clears the predicate set efficiently
|
| 858 |
+
CUTLASS_HOST_DEVICE
|
| 859 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 860 |
+
|
| 861 |
+
/// Clears the predicate set efficiently
|
| 862 |
+
CUTLASS_HOST_DEVICE
|
| 863 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 864 |
+
|
| 865 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 866 |
+
CUTLASS_HOST_DEVICE
|
| 867 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 868 |
+
|
| 869 |
+
/// Gets the mask
|
| 870 |
+
CUTLASS_HOST_DEVICE
|
| 871 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 872 |
+
|
| 873 |
+
/// Return if the address in on the diagonal
|
| 874 |
+
CUTLASS_HOST_DEVICE
|
| 875 |
+
bool getOnDiag() {
|
| 876 |
+
return iterator_.getOnDiag();
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
/// Returns whether access is valid or not
|
| 880 |
+
CUTLASS_HOST_DEVICE
|
| 881 |
+
bool valid() {
|
| 882 |
+
return iterator_.valid();
|
| 883 |
+
}
|
| 884 |
+
};
|
| 885 |
+
|
| 886 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 887 |
+
|
| 888 |
+
} // namespace threadblock
|
| 889 |
+
} // namespace transform
|
| 890 |
+
} // namespace cutlass
|
| 891 |
+
|
| 892 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h
ADDED
|
@@ -0,0 +1,1887 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
|
| 33 |
+
|
| 34 |
+
This iterator uses masks to guard out-of-bounds accesses. The first tile this
|
| 35 |
+
iterator visits maybe partial, then the remaining tiles are complete. So, we
|
| 36 |
+
only need to compute the predicates twice, once before the first tile and
|
| 37 |
+
once for the remaining full tiles which can share the same predicates.
|
| 38 |
+
|
| 39 |
+
A precomputed "Params" object minimizes the amount of state that must be stored in registers,
|
| 40 |
+
and integer addition is used to advance the pointer through memory.
|
| 41 |
+
*/
|
| 42 |
+
|
| 43 |
+
#pragma once
|
| 44 |
+
|
| 45 |
+
#include "cutlass/arch/memory.h"
|
| 46 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace transform {
|
| 52 |
+
namespace threadblock {
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// PredicatedTileIterator
|
| 57 |
+
///
|
| 58 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 59 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 60 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 61 |
+
/// MaskedTileIteratorConcept
|
| 62 |
+
///
|
| 63 |
+
/// Regular tile iterator using a precomputed control structure to minimize register liveness
|
| 64 |
+
/// and integer arithmetic.
|
| 65 |
+
///
|
| 66 |
+
/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
|
| 67 |
+
///
|
| 68 |
+
/// Base pointer and tensor extents may be specified at the time the iterator is constructed.
|
| 69 |
+
/// Subsequently, they are assumed to be immutable.
|
| 70 |
+
///
|
| 71 |
+
/// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
|
| 72 |
+
/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
|
| 73 |
+
///
|
| 74 |
+
/// Visitation order is intended to first visit a "residual" tile that may be partially full in
|
| 75 |
+
/// both the advance dimension and the steady-state dimension. This is assumed to be the last
|
| 76 |
+
/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
|
| 77 |
+
/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
|
| 78 |
+
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
| 79 |
+
/// live register state and pointer arithmetic instructions.
|
| 80 |
+
///
|
| 81 |
+
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
| 82 |
+
/// outside any looping structure to minimize integer arithmetic.
|
| 83 |
+
///
|
| 84 |
+
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
| 85 |
+
/// the iterator.
|
| 86 |
+
///
|
| 87 |
+
///
|
| 88 |
+
/// Example:
|
| 89 |
+
///
|
| 90 |
+
/// An efficient pipeline structure may be constructed as follows:
|
| 91 |
+
///
|
| 92 |
+
// template <typename Iterator>
|
| 93 |
+
// __global__ void kernel(
|
| 94 |
+
// typename Iterator::Params params,
|
| 95 |
+
// typename Iterator::Element *ptr,
|
| 96 |
+
// TensorCoord extent) {
|
| 97 |
+
//
|
| 98 |
+
// typename Iterator::Fragment fragment;
|
| 99 |
+
//
|
| 100 |
+
// TensorCoord threadblock_offset(0, 0);
|
| 101 |
+
//
|
| 102 |
+
// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
|
| 103 |
+
//
|
| 104 |
+
//
|
| 105 |
+
// fragment = *iter; // load "residue" tile first
|
| 106 |
+
// ++iter; // advance to first "steady state" tile and update internal masks
|
| 107 |
+
//
|
| 108 |
+
//
|
| 109 |
+
// #pragma unroll
|
| 110 |
+
// for (int i = Remaining - 1; i >= 0; --i) {
|
| 111 |
+
//
|
| 112 |
+
// f(fragment);
|
| 113 |
+
//
|
| 114 |
+
// if (!i) {
|
| 115 |
+
// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
|
| 116 |
+
// }
|
| 117 |
+
//
|
| 118 |
+
// fragment = *iter; // load tile during "steady state" phase
|
| 119 |
+
// ++iter; // advance to next tile - lightweight due to steady-state masks
|
| 120 |
+
// }
|
| 121 |
+
// }
|
| 122 |
+
//
|
| 123 |
+
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
|
| 124 |
+
//
|
| 125 |
+
// using Iterator = transform::threadblock::PredicatedTileIterator;
|
| 126 |
+
//
|
| 127 |
+
// typename Iterator::Params params(view.layout());
|
| 128 |
+
//
|
| 129 |
+
// kernel<Iterator>(params, view.data());
|
| 130 |
+
// }
|
| 131 |
+
///
|
| 132 |
+
///
|
| 133 |
+
template <
|
| 134 |
+
typename Shape,
|
| 135 |
+
typename Element,
|
| 136 |
+
typename Layout,
|
| 137 |
+
int AdvanceRank,
|
| 138 |
+
typename ThreadMap,
|
| 139 |
+
int AccessSize = ThreadMap::kElementsPerAccess,
|
| 140 |
+
bool Gather = false,
|
| 141 |
+
typename PermuteLayout = layout::NoPermute
|
| 142 |
+
>
|
| 143 |
+
class PredicatedTileIterator;
|
| 144 |
+
|
| 145 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 146 |
+
|
| 147 |
+
/// Specialization of PredicatedTileIterator for pitch-linear data.
|
| 148 |
+
///
|
| 149 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 150 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 151 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 152 |
+
/// MaskedTileIteratorConcept
|
| 153 |
+
///
|
| 154 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 155 |
+
typename ThreadMap_, int AccessSize, bool Gather, typename PermuteLayout>
|
| 156 |
+
class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
|
| 157 |
+
ThreadMap_, AccessSize, Gather, PermuteLayout> {
|
| 158 |
+
public:
|
| 159 |
+
static_assert(
|
| 160 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 161 |
+
"Specialization for pitch-linear iterator may advance along the "
|
| 162 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 163 |
+
|
| 164 |
+
using Shape = Shape_;
|
| 165 |
+
using Element = Element_;
|
| 166 |
+
using Layout = layout::PitchLinear;
|
| 167 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 168 |
+
using ThreadMap = ThreadMap_;
|
| 169 |
+
|
| 170 |
+
using Index = typename Layout::Index;
|
| 171 |
+
using LongIndex = typename Layout::LongIndex;
|
| 172 |
+
|
| 173 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 174 |
+
using TensorView = TensorView<Element, Layout>;
|
| 175 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 176 |
+
|
| 177 |
+
using Pointer = Element *;
|
| 178 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 179 |
+
|
| 180 |
+
/// Type used for internal memory accesses
|
| 181 |
+
using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 182 |
+
|
| 183 |
+
/// Underlying iterator to compute the addresses
|
| 184 |
+
using TileAccessIterator =
|
| 185 |
+
PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
|
| 186 |
+
ThreadMap, AccessType, Gather, PermuteLayout>;
|
| 187 |
+
|
| 188 |
+
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 189 |
+
|
| 190 |
+
/// Fragment object to be loaded or stored
|
| 191 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 192 |
+
ThreadMap::kElementsPerAccess>;
|
| 193 |
+
|
| 194 |
+
/// Predicate vector stores mask to guard accesses
|
| 195 |
+
using Mask = typename TileAccessIterator::Mask;
|
| 196 |
+
|
| 197 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 198 |
+
class Params {
|
| 199 |
+
public:
|
| 200 |
+
using Base = typename TileAccessIterator::Params::Base;
|
| 201 |
+
|
| 202 |
+
friend PredicatedTileIterator;
|
| 203 |
+
|
| 204 |
+
private:
|
| 205 |
+
/// Parameters object
|
| 206 |
+
typename TileAccessIterator::Params params_;
|
| 207 |
+
|
| 208 |
+
public:
|
| 209 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 210 |
+
CUTLASS_HOST_DEVICE
|
| 211 |
+
Params(Layout const &layout) : params_(layout) {}
|
| 212 |
+
|
| 213 |
+
/// Default constructor
|
| 214 |
+
Params() = default;
|
| 215 |
+
|
| 216 |
+
CUTLASS_HOST_DEVICE
|
| 217 |
+
Params(Base const &base)
|
| 218 |
+
: params_(base) {}
|
| 219 |
+
};
|
| 220 |
+
|
| 221 |
+
private:
|
| 222 |
+
/// Internal pointer type permits fast address arithmetic
|
| 223 |
+
using BytePointer = char *;
|
| 224 |
+
|
| 225 |
+
private:
|
| 226 |
+
//
|
| 227 |
+
// Data members
|
| 228 |
+
//
|
| 229 |
+
|
| 230 |
+
/// Data member to the tile access iterator
|
| 231 |
+
TileAccessIterator address_iterator_;
|
| 232 |
+
|
| 233 |
+
public:
|
| 234 |
+
|
| 235 |
+
/// Default constructor
|
| 236 |
+
PredicatedTileIterator() = default;
|
| 237 |
+
|
| 238 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 239 |
+
/// and thread ID
|
| 240 |
+
CUTLASS_HOST_DEVICE
|
| 241 |
+
PredicatedTileIterator(
|
| 242 |
+
/// Precomputed parameters object
|
| 243 |
+
Params const ¶ms,
|
| 244 |
+
/// Pointer to start of tensor
|
| 245 |
+
Pointer pointer,
|
| 246 |
+
/// Extent of tensor
|
| 247 |
+
TensorCoord extent,
|
| 248 |
+
/// ID of each participating thread
|
| 249 |
+
int thread_id,
|
| 250 |
+
/// Initial offset of threadblock
|
| 251 |
+
TensorCoord const &threadblock_offset,
|
| 252 |
+
/// Gather indices
|
| 253 |
+
int const *indices = nullptr)
|
| 254 |
+
: address_iterator_(params.params_, pointer, extent, thread_id,
|
| 255 |
+
threadblock_offset, indices) {}
|
| 256 |
+
|
| 257 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 258 |
+
CUTLASS_HOST_DEVICE
|
| 259 |
+
PredicatedTileIterator(
|
| 260 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 261 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 262 |
+
TensorCoord extent, ///< Extent of tensor
|
| 263 |
+
int thread_id ///< ID of each participating thread
|
| 264 |
+
)
|
| 265 |
+
: PredicatedTileIterator(params, pointer, extent, thread_id,
|
| 266 |
+
make_Coord(0, 0)) {}
|
| 267 |
+
|
| 268 |
+
/// Adds a pointer offset in units of Element
|
| 269 |
+
CUTLASS_HOST_DEVICE
|
| 270 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 271 |
+
address_iterator_.add_pointer_offset(pointer_offset);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
/// Advances to the next tile in memory.
|
| 275 |
+
///
|
| 276 |
+
/// The first time this method is called, predicates are updated, and the
|
| 277 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 278 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 279 |
+
/// pointer.
|
| 280 |
+
CUTLASS_HOST_DEVICE
|
| 281 |
+
PredicatedTileIterator &operator++() {
|
| 282 |
+
if (kAdvanceRank)
|
| 283 |
+
address_iterator_.add_tile_offset({0, 1});
|
| 284 |
+
else
|
| 285 |
+
address_iterator_.add_tile_offset({1, 0});
|
| 286 |
+
|
| 287 |
+
return *this;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
/// Advances to the next tile in memory.
|
| 291 |
+
///
|
| 292 |
+
/// The first time this method is called, predicates are updated, and the
|
| 293 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 294 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 295 |
+
/// pointer.
|
| 296 |
+
CUTLASS_HOST_DEVICE
|
| 297 |
+
PredicatedTileIterator operator++(int) {
|
| 298 |
+
PredicatedTileIterator self(*this);
|
| 299 |
+
operator++();
|
| 300 |
+
return self;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
/// Clears the predicate set efficiently
|
| 304 |
+
CUTLASS_HOST_DEVICE
|
| 305 |
+
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
| 306 |
+
|
| 307 |
+
/// Clears the predicate set efficiently
|
| 308 |
+
CUTLASS_HOST_DEVICE
|
| 309 |
+
void enable_mask() { address_iterator_.enable_mask(); }
|
| 310 |
+
|
| 311 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 312 |
+
CUTLASS_HOST_DEVICE
|
| 313 |
+
void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
|
| 314 |
+
|
| 315 |
+
/// Gets the mask
|
| 316 |
+
CUTLASS_HOST_DEVICE
|
| 317 |
+
void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
|
| 318 |
+
|
| 319 |
+
CUTLASS_DEVICE
|
| 320 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 321 |
+
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
CUTLASS_DEVICE
|
| 325 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 326 |
+
|
| 327 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 328 |
+
|
| 329 |
+
CUTLASS_PRAGMA_UNROLL
|
| 330 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 331 |
+
CUTLASS_PRAGMA_UNROLL
|
| 332 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 333 |
+
|
| 334 |
+
CUTLASS_PRAGMA_UNROLL
|
| 335 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 336 |
+
|
| 337 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 338 |
+
|
| 339 |
+
address_iterator_.set_iteration_index(idx);
|
| 340 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
| 341 |
+
|
| 342 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 343 |
+
|
| 344 |
+
cutlass::arch::global_load<AccessType,
|
| 345 |
+
sizeof(AccessType)
|
| 346 |
+
>(
|
| 347 |
+
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 348 |
+
|
| 349 |
+
++address_iterator_;
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
}
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
/// Loads a fragment from memory
|
| 356 |
+
CUTLASS_DEVICE
|
| 357 |
+
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
|
| 358 |
+
|
| 359 |
+
/// Store a fragment to memory
|
| 360 |
+
CUTLASS_DEVICE
|
| 361 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 362 |
+
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
/// Store a fragment to memory
|
| 366 |
+
CUTLASS_DEVICE
|
| 367 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 368 |
+
address_iterator_.set_iteration_index(0);
|
| 369 |
+
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
| 370 |
+
|
| 371 |
+
CUTLASS_PRAGMA_UNROLL
|
| 372 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 373 |
+
CUTLASS_PRAGMA_UNROLL
|
| 374 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 375 |
+
CUTLASS_PRAGMA_UNROLL
|
| 376 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 377 |
+
|
| 378 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 379 |
+
|
| 380 |
+
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
| 381 |
+
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
| 382 |
+
|
| 383 |
+
if (address_iterator_.valid()) {
|
| 384 |
+
*access_ptr = frag_ptr[idx];
|
| 385 |
+
}
|
| 386 |
+
++address_iterator_;
|
| 387 |
+
}
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/// Store a fragment to memory
|
| 393 |
+
CUTLASS_DEVICE
|
| 394 |
+
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
|
| 395 |
+
};
|
| 396 |
+
|
| 397 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 398 |
+
|
| 399 |
+
/// Specialization of PredicatedTileIterator for column-major data.
|
| 400 |
+
///
|
| 401 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 402 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 403 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 404 |
+
/// MaskedTileIteratorConcept
|
| 405 |
+
///
|
| 406 |
+
template <
|
| 407 |
+
typename Shape_,
|
| 408 |
+
typename Element_,
|
| 409 |
+
int AdvanceRank,
|
| 410 |
+
typename ThreadMap_,
|
| 411 |
+
int AccessSize,
|
| 412 |
+
bool Gather,
|
| 413 |
+
typename PermuteLayout
|
| 414 |
+
>
|
| 415 |
+
class PredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank,
|
| 416 |
+
ThreadMap_, AccessSize, Gather, PermuteLayout> {
|
| 417 |
+
public:
|
| 418 |
+
|
| 419 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 420 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 421 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 422 |
+
|
| 423 |
+
using Shape = Shape_;
|
| 424 |
+
using Element = Element_;
|
| 425 |
+
using Layout = layout::ColumnMajor;
|
| 426 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 427 |
+
using ThreadMap = ThreadMap_;
|
| 428 |
+
|
| 429 |
+
using Index = typename Layout::Index;
|
| 430 |
+
using LongIndex = typename Layout::LongIndex;
|
| 431 |
+
|
| 432 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 433 |
+
using TensorView = TensorView<Element, Layout>;
|
| 434 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 435 |
+
|
| 436 |
+
using Pointer = Element *;
|
| 437 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 438 |
+
|
| 439 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 440 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 441 |
+
Element,
|
| 442 |
+
layout::PitchLinear,
|
| 443 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 444 |
+
ThreadMap,
|
| 445 |
+
AccessSize,
|
| 446 |
+
Gather,
|
| 447 |
+
PermuteLayout
|
| 448 |
+
>;
|
| 449 |
+
|
| 450 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 451 |
+
|
| 452 |
+
/// Fragment object to be loaded or stored
|
| 453 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 454 |
+
|
| 455 |
+
/// Predicate vector stores mask to guard accesses
|
| 456 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 457 |
+
|
| 458 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 459 |
+
class Params {
|
| 460 |
+
private:
|
| 461 |
+
|
| 462 |
+
friend PredicatedTileIterator;
|
| 463 |
+
|
| 464 |
+
/// Parameters object
|
| 465 |
+
typename UnderlyingIterator::Params params_;
|
| 466 |
+
|
| 467 |
+
public:
|
| 468 |
+
|
| 469 |
+
/// Default constructor
|
| 470 |
+
Params() = default;
|
| 471 |
+
|
| 472 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 473 |
+
CUTLASS_HOST_DEVICE
|
| 474 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0)))
|
| 475 |
+
{}
|
| 476 |
+
|
| 477 |
+
CUTLASS_HOST_DEVICE
|
| 478 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 479 |
+
: params_(base) {}
|
| 480 |
+
};
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
private:
|
| 484 |
+
|
| 485 |
+
//
|
| 486 |
+
// Data members
|
| 487 |
+
//
|
| 488 |
+
|
| 489 |
+
/// Underlying pitch-linear tile iterator
|
| 490 |
+
UnderlyingIterator iterator_;
|
| 491 |
+
|
| 492 |
+
public:
|
| 493 |
+
|
| 494 |
+
/// Default constructor
|
| 495 |
+
PredicatedTileIterator() = default;
|
| 496 |
+
|
| 497 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 498 |
+
CUTLASS_HOST_DEVICE
|
| 499 |
+
PredicatedTileIterator(
|
| 500 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 501 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 502 |
+
TensorCoord extent, ///< Extent of tensor
|
| 503 |
+
int thread_id, ///< ID of each participating thread
|
| 504 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 505 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 506 |
+
):
|
| 507 |
+
iterator_(
|
| 508 |
+
params.params_,
|
| 509 |
+
pointer,
|
| 510 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 511 |
+
thread_id,
|
| 512 |
+
layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()),
|
| 513 |
+
indices)
|
| 514 |
+
{ }
|
| 515 |
+
|
| 516 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 517 |
+
CUTLASS_HOST_DEVICE
|
| 518 |
+
PredicatedTileIterator(
|
| 519 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 520 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 521 |
+
TensorCoord extent, ///< Extent of tensor
|
| 522 |
+
int thread_id ///< ID of each participating thread
|
| 523 |
+
): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 524 |
+
|
| 525 |
+
/// Adds a pointer offset in units of Element
|
| 526 |
+
CUTLASS_HOST_DEVICE
|
| 527 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 528 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 529 |
+
}
|
| 530 |
+
|
| 531 |
+
/// Advances to the next tile in memory.
|
| 532 |
+
///
|
| 533 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 534 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 535 |
+
/// are lightweight and must only update the internal pointer.
|
| 536 |
+
CUTLASS_HOST_DEVICE
|
| 537 |
+
PredicatedTileIterator &operator++() {
|
| 538 |
+
++iterator_;
|
| 539 |
+
return *this;
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
/// Advances to the next tile in memory.
|
| 543 |
+
///
|
| 544 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 545 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 546 |
+
/// are lightweight and must only update the internal pointer.
|
| 547 |
+
CUTLASS_HOST_DEVICE
|
| 548 |
+
PredicatedTileIterator operator++(int) {
|
| 549 |
+
PredicatedTileIterator self(*this);
|
| 550 |
+
operator++();
|
| 551 |
+
return self;
|
| 552 |
+
}
|
| 553 |
+
|
| 554 |
+
/// Clears the predicate set efficiently
|
| 555 |
+
CUTLASS_HOST_DEVICE
|
| 556 |
+
void clear_mask(bool enable = true) {
|
| 557 |
+
iterator_.clear_mask(enable);
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
/// Clears the predicate set efficiently
|
| 561 |
+
CUTLASS_HOST_DEVICE
|
| 562 |
+
void enable_mask() {
|
| 563 |
+
iterator_.enable_mask();
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 567 |
+
CUTLASS_HOST_DEVICE
|
| 568 |
+
void set_mask(Mask const &mask) {
|
| 569 |
+
iterator_.set_mask(mask);
|
| 570 |
+
}
|
| 571 |
+
|
| 572 |
+
/// Gets the mask
|
| 573 |
+
CUTLASS_HOST_DEVICE
|
| 574 |
+
void get_mask(Mask &mask) {
|
| 575 |
+
iterator_.get_mask(mask);
|
| 576 |
+
}
|
| 577 |
+
|
| 578 |
+
/// Loads a fragment from memory
|
| 579 |
+
CUTLASS_DEVICE
|
| 580 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 581 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
/// Loads a fragment from memory
|
| 585 |
+
CUTLASS_DEVICE
|
| 586 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 587 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
/// Loads a fragment from memory
|
| 591 |
+
CUTLASS_DEVICE
|
| 592 |
+
void load(Fragment &frag) {
|
| 593 |
+
load_with_pointer_offset(frag, 0);
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
/// Store a fragment to memory
|
| 597 |
+
CUTLASS_DEVICE
|
| 598 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 599 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 600 |
+
}
|
| 601 |
+
|
| 602 |
+
/// Store a fragment to memory
|
| 603 |
+
CUTLASS_DEVICE
|
| 604 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 605 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 606 |
+
}
|
| 607 |
+
|
| 608 |
+
/// Store a fragment to memory
|
| 609 |
+
CUTLASS_DEVICE
|
| 610 |
+
void store(Fragment const &frag) {
|
| 611 |
+
store_with_pointer_offset(frag, 0);
|
| 612 |
+
}
|
| 613 |
+
};
|
| 614 |
+
|
| 615 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 616 |
+
|
| 617 |
+
/// Specialization of PredicatedTileIterator for row-major data.
|
| 618 |
+
///
|
| 619 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 620 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 621 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 622 |
+
/// MaskedTileIteratorConcept
|
| 623 |
+
///
|
| 624 |
+
template <
|
| 625 |
+
typename Shape_,
|
| 626 |
+
typename Element_,
|
| 627 |
+
int AdvanceRank,
|
| 628 |
+
typename ThreadMap_,
|
| 629 |
+
int AccessSize,
|
| 630 |
+
bool Gather,
|
| 631 |
+
typename PermuteLayout
|
| 632 |
+
>
|
| 633 |
+
class PredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank,
|
| 634 |
+
ThreadMap_, AccessSize, Gather, PermuteLayout> {
|
| 635 |
+
public:
|
| 636 |
+
|
| 637 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 638 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 639 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 640 |
+
|
| 641 |
+
using Shape = Shape_;
|
| 642 |
+
using Element = Element_;
|
| 643 |
+
using Layout = layout::RowMajor;
|
| 644 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 645 |
+
using ThreadMap = ThreadMap_;
|
| 646 |
+
|
| 647 |
+
using Index = typename Layout::Index;
|
| 648 |
+
using LongIndex = typename Layout::LongIndex;
|
| 649 |
+
|
| 650 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 651 |
+
using TensorView = TensorView<Element, Layout>;
|
| 652 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 653 |
+
|
| 654 |
+
using Pointer = Element *;
|
| 655 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 656 |
+
|
| 657 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 658 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 659 |
+
Element,
|
| 660 |
+
layout::PitchLinear,
|
| 661 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 662 |
+
ThreadMap,
|
| 663 |
+
AccessSize,
|
| 664 |
+
Gather,
|
| 665 |
+
PermuteLayout
|
| 666 |
+
>;
|
| 667 |
+
|
| 668 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 669 |
+
|
| 670 |
+
/// Fragment object to be loaded or stored
|
| 671 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 672 |
+
|
| 673 |
+
/// Predicate vector stores mask to guard accesses
|
| 674 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 675 |
+
|
| 676 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 677 |
+
class Params {
|
| 678 |
+
private:
|
| 679 |
+
|
| 680 |
+
friend PredicatedTileIterator;
|
| 681 |
+
|
| 682 |
+
/// Parameters object
|
| 683 |
+
typename UnderlyingIterator::Params params_;
|
| 684 |
+
|
| 685 |
+
public:
|
| 686 |
+
|
| 687 |
+
/// Default constructor
|
| 688 |
+
Params() = default;
|
| 689 |
+
|
| 690 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 691 |
+
CUTLASS_HOST_DEVICE
|
| 692 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}
|
| 693 |
+
|
| 694 |
+
CUTLASS_HOST_DEVICE
|
| 695 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 696 |
+
: params_(base) {}
|
| 697 |
+
|
| 698 |
+
};
|
| 699 |
+
|
| 700 |
+
private:
|
| 701 |
+
|
| 702 |
+
//
|
| 703 |
+
// Data members
|
| 704 |
+
//
|
| 705 |
+
|
| 706 |
+
/// Underlying pitch-linear tile iterator
|
| 707 |
+
UnderlyingIterator iterator_;
|
| 708 |
+
|
| 709 |
+
public:
|
| 710 |
+
|
| 711 |
+
/// Default constructor
|
| 712 |
+
PredicatedTileIterator() = default;
|
| 713 |
+
|
| 714 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 715 |
+
CUTLASS_HOST_DEVICE
|
| 716 |
+
PredicatedTileIterator(
|
| 717 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 718 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 719 |
+
TensorCoord extent, ///< Extent of tensor
|
| 720 |
+
int thread_id, ///< ID of each participating thread
|
| 721 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 722 |
+
int const *indices = nullptr ///< Gather indices
|
| 723 |
+
):
|
| 724 |
+
iterator_(
|
| 725 |
+
params.params_,
|
| 726 |
+
pointer,
|
| 727 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 728 |
+
thread_id,
|
| 729 |
+
layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()),
|
| 730 |
+
indices
|
| 731 |
+
) { }
|
| 732 |
+
|
| 733 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 734 |
+
CUTLASS_HOST_DEVICE
|
| 735 |
+
PredicatedTileIterator(
|
| 736 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 737 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 738 |
+
TensorCoord extent, ///< Extent of tensor
|
| 739 |
+
int thread_id ///< ID of each participating thread
|
| 740 |
+
): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 741 |
+
|
| 742 |
+
/// Adds a pointer offset in units of Element
|
| 743 |
+
CUTLASS_HOST_DEVICE
|
| 744 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 745 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
/// Advances to the next tile in memory.
|
| 749 |
+
///
|
| 750 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 751 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 752 |
+
/// are lightweight and must only update the internal pointer.
|
| 753 |
+
CUTLASS_HOST_DEVICE
|
| 754 |
+
PredicatedTileIterator &operator++() {
|
| 755 |
+
++iterator_;
|
| 756 |
+
return *this;
|
| 757 |
+
}
|
| 758 |
+
|
| 759 |
+
/// Advances to the next tile in memory.
|
| 760 |
+
///
|
| 761 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 762 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 763 |
+
/// are lightweight and must only update the internal pointer.
|
| 764 |
+
CUTLASS_HOST_DEVICE
|
| 765 |
+
PredicatedTileIterator operator++(int) {
|
| 766 |
+
PredicatedTileIterator self(*this);
|
| 767 |
+
operator++();
|
| 768 |
+
return self;
|
| 769 |
+
}
|
| 770 |
+
|
| 771 |
+
/// Clears the predicate set efficiently
|
| 772 |
+
CUTLASS_HOST_DEVICE
|
| 773 |
+
void clear_mask(bool enable = true) {
|
| 774 |
+
iterator_.clear_mask(enable);
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
/// Clears the predicate set efficiently
|
| 778 |
+
CUTLASS_HOST_DEVICE
|
| 779 |
+
void enable_mask() {
|
| 780 |
+
iterator_.enable_mask();
|
| 781 |
+
}
|
| 782 |
+
|
| 783 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 784 |
+
CUTLASS_HOST_DEVICE
|
| 785 |
+
void set_mask(Mask const &mask) {
|
| 786 |
+
iterator_.set_mask(mask);
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
/// Gets the mask
|
| 790 |
+
CUTLASS_HOST_DEVICE
|
| 791 |
+
void get_mask(Mask &mask) {
|
| 792 |
+
iterator_.get_mask(mask);
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
/// Loads a fragment from memory
|
| 796 |
+
CUTLASS_DEVICE
|
| 797 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 798 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
/// Loads a fragment from memory
|
| 802 |
+
CUTLASS_DEVICE
|
| 803 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 804 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
/// Loads a fragment from memory
|
| 808 |
+
CUTLASS_DEVICE
|
| 809 |
+
void load(Fragment &frag) {
|
| 810 |
+
load_with_pointer_offset(frag, 0);
|
| 811 |
+
}
|
| 812 |
+
|
| 813 |
+
/// Store a fragment to memory
|
| 814 |
+
CUTLASS_DEVICE
|
| 815 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 816 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
/// Store a fragment to memory
|
| 820 |
+
CUTLASS_DEVICE
|
| 821 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 822 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
/// Store a fragment to memory
|
| 826 |
+
CUTLASS_DEVICE
|
| 827 |
+
void store(Fragment const &frag) {
|
| 828 |
+
store_with_pointer_offset(frag, 0);
|
| 829 |
+
}
|
| 830 |
+
};
|
| 831 |
+
|
| 832 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 833 |
+
|
| 834 |
+
/// Specialization of PredicatedTileIterator for affine rank-2 data.
|
| 835 |
+
///
|
| 836 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 837 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 838 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 839 |
+
/// MaskedTileIteratorConcept
|
| 840 |
+
///
|
| 841 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 842 |
+
typename ThreadMap_, int AccessSize>
|
| 843 |
+
class PredicatedTileIterator<Shape_, Element_, layout::AffineRankN<2>, AdvanceRank,
|
| 844 |
+
ThreadMap_, AccessSize, false> {
|
| 845 |
+
public:
|
| 846 |
+
static_assert(
|
| 847 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 848 |
+
"Specialization for pitch-linear iterator may advance along the "
|
| 849 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 850 |
+
|
| 851 |
+
using Shape = Shape_;
|
| 852 |
+
using Element = Element_;
|
| 853 |
+
using Layout = layout::AffineRankN<2>;
|
| 854 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 855 |
+
using ThreadMap = ThreadMap_;
|
| 856 |
+
|
| 857 |
+
using Index = typename Layout::Index;
|
| 858 |
+
using LongIndex = typename Layout::LongIndex;
|
| 859 |
+
|
| 860 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 861 |
+
using TensorView = TensorView<Element, Layout>;
|
| 862 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 863 |
+
|
| 864 |
+
using Pointer = Element *;
|
| 865 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 866 |
+
|
| 867 |
+
/// Type used for internal memory accesses
|
| 868 |
+
using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 869 |
+
|
| 870 |
+
/// Underlying iterator to compute the addresses
|
| 871 |
+
using TileAccessIterator =
|
| 872 |
+
PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
|
| 873 |
+
ThreadMap, AccessType>;
|
| 874 |
+
|
| 875 |
+
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 876 |
+
|
| 877 |
+
/// Fragment object to be loaded or stored
|
| 878 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 879 |
+
ThreadMap::kElementsPerAccess>;
|
| 880 |
+
|
| 881 |
+
/// Predicate vector stores mask to guard accesses
|
| 882 |
+
using Mask = typename TileAccessIterator::Mask;
|
| 883 |
+
|
| 884 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 885 |
+
class Params {
|
| 886 |
+
public:
|
| 887 |
+
|
| 888 |
+
friend PredicatedTileIterator;
|
| 889 |
+
|
| 890 |
+
private:
|
| 891 |
+
/// Parameters object
|
| 892 |
+
typename TileAccessIterator::Params params_;
|
| 893 |
+
|
| 894 |
+
public:
|
| 895 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 896 |
+
CUTLASS_HOST_DEVICE
|
| 897 |
+
Params(Layout const &layout) : params_(layout) {}
|
| 898 |
+
|
| 899 |
+
/// Default constructor
|
| 900 |
+
Params() = default;
|
| 901 |
+
};
|
| 902 |
+
|
| 903 |
+
private:
|
| 904 |
+
/// Internal pointer type permits fast address arithmetic
|
| 905 |
+
using BytePointer = char *;
|
| 906 |
+
|
| 907 |
+
private:
|
| 908 |
+
//
|
| 909 |
+
// Data members
|
| 910 |
+
//
|
| 911 |
+
|
| 912 |
+
/// Data member to the tile access iterator
|
| 913 |
+
TileAccessIterator address_iterator_;
|
| 914 |
+
|
| 915 |
+
public:
|
| 916 |
+
|
| 917 |
+
/// Default constructor
|
| 918 |
+
PredicatedTileIterator() = default;
|
| 919 |
+
|
| 920 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 921 |
+
/// and thread ID
|
| 922 |
+
CUTLASS_HOST_DEVICE
|
| 923 |
+
PredicatedTileIterator(
|
| 924 |
+
/// Precomputed parameters object
|
| 925 |
+
Params const ¶ms,
|
| 926 |
+
/// Pointer to start of tensor
|
| 927 |
+
Pointer pointer,
|
| 928 |
+
/// Extent of tensor
|
| 929 |
+
TensorCoord extent,
|
| 930 |
+
/// ID of each participating thread
|
| 931 |
+
int thread_id,
|
| 932 |
+
/// Initial offset of threadblock
|
| 933 |
+
TensorCoord const &threadblock_offset,
|
| 934 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 935 |
+
)
|
| 936 |
+
: address_iterator_(params.params_, pointer, extent, thread_id,
|
| 937 |
+
threadblock_offset) {}
|
| 938 |
+
|
| 939 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 940 |
+
CUTLASS_HOST_DEVICE
|
| 941 |
+
PredicatedTileIterator(
|
| 942 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 943 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 944 |
+
TensorCoord extent, ///< Extent of tensor
|
| 945 |
+
int thread_id ///< ID of each participating thread
|
| 946 |
+
)
|
| 947 |
+
: PredicatedTileIterator(params, pointer, extent, thread_id,
|
| 948 |
+
make_Coord(0, 0)) {}
|
| 949 |
+
|
| 950 |
+
/// Adds a pointer offset in units of Element
|
| 951 |
+
CUTLASS_HOST_DEVICE
|
| 952 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 953 |
+
address_iterator_.add_pointer_offset(pointer_offset);
|
| 954 |
+
}
|
| 955 |
+
|
| 956 |
+
/// Advances to the next tile in memory.
|
| 957 |
+
///
|
| 958 |
+
/// The first time this method is called, predicates are updated, and the
|
| 959 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 960 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 961 |
+
/// pointer.
|
| 962 |
+
CUTLASS_HOST_DEVICE
|
| 963 |
+
PredicatedTileIterator &operator++() {
|
| 964 |
+
if (kAdvanceRank)
|
| 965 |
+
address_iterator_.add_tile_offset(make_Coord(0, 1));
|
| 966 |
+
else
|
| 967 |
+
address_iterator_.add_tile_offset(make_Coord(1, 0));
|
| 968 |
+
|
| 969 |
+
return *this;
|
| 970 |
+
}
|
| 971 |
+
|
| 972 |
+
/// Advances to the next tile in memory.
|
| 973 |
+
///
|
| 974 |
+
/// The first time this method is called, predicates are updated, and the
|
| 975 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 976 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 977 |
+
/// pointer.
|
| 978 |
+
CUTLASS_HOST_DEVICE
|
| 979 |
+
PredicatedTileIterator operator++(int) {
|
| 980 |
+
PredicatedTileIterator self(*this);
|
| 981 |
+
operator++();
|
| 982 |
+
return self;
|
| 983 |
+
}
|
| 984 |
+
|
| 985 |
+
/// Clears the predicate set efficiently
|
| 986 |
+
CUTLASS_HOST_DEVICE
|
| 987 |
+
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
| 988 |
+
|
| 989 |
+
/// Clears the predicate set efficiently
|
| 990 |
+
CUTLASS_HOST_DEVICE
|
| 991 |
+
void enable_mask() { address_iterator_.enable_mask(); }
|
| 992 |
+
|
| 993 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 994 |
+
CUTLASS_HOST_DEVICE
|
| 995 |
+
void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
|
| 996 |
+
|
| 997 |
+
/// Gets the mask
|
| 998 |
+
CUTLASS_HOST_DEVICE
|
| 999 |
+
void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
|
| 1000 |
+
|
| 1001 |
+
CUTLASS_DEVICE
|
| 1002 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1003 |
+
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 1004 |
+
}
|
| 1005 |
+
|
| 1006 |
+
CUTLASS_DEVICE
|
| 1007 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 1008 |
+
|
| 1009 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 1010 |
+
|
| 1011 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1012 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 1013 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1014 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 1015 |
+
|
| 1016 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1017 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 1018 |
+
|
| 1019 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 1020 |
+
|
| 1021 |
+
address_iterator_.set_iteration_index(idx);
|
| 1022 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
| 1023 |
+
|
| 1024 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 1025 |
+
|
| 1026 |
+
cutlass::arch::global_load<AccessType,
|
| 1027 |
+
sizeof(AccessType)
|
| 1028 |
+
>(
|
| 1029 |
+
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 1030 |
+
|
| 1031 |
+
++address_iterator_;
|
| 1032 |
+
}
|
| 1033 |
+
}
|
| 1034 |
+
}
|
| 1035 |
+
}
|
| 1036 |
+
|
| 1037 |
+
/// Loads a fragment from memory
|
| 1038 |
+
CUTLASS_DEVICE
|
| 1039 |
+
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
|
| 1040 |
+
|
| 1041 |
+
/// Store a fragment to memory
|
| 1042 |
+
CUTLASS_DEVICE
|
| 1043 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1044 |
+
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 1045 |
+
}
|
| 1046 |
+
|
| 1047 |
+
/// Store a fragment to memory
|
| 1048 |
+
CUTLASS_DEVICE
|
| 1049 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 1050 |
+
address_iterator_.set_iteration_index(0);
|
| 1051 |
+
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
| 1052 |
+
|
| 1053 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1054 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 1055 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1056 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 1057 |
+
CUTLASS_PRAGMA_UNROLL
|
| 1058 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 1059 |
+
|
| 1060 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 1061 |
+
|
| 1062 |
+
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
| 1063 |
+
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
| 1064 |
+
|
| 1065 |
+
if (address_iterator_.valid()) {
|
| 1066 |
+
*access_ptr = frag_ptr[idx];
|
| 1067 |
+
}
|
| 1068 |
+
++address_iterator_;
|
| 1069 |
+
}
|
| 1070 |
+
}
|
| 1071 |
+
}
|
| 1072 |
+
}
|
| 1073 |
+
|
| 1074 |
+
/// Store a fragment to memory
|
| 1075 |
+
CUTLASS_DEVICE
|
| 1076 |
+
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
|
| 1077 |
+
};
|
| 1078 |
+
|
| 1079 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1080 |
+
|
| 1081 |
+
/// Specialization of PredicatedTileIterator for affine rank 2 column-major data.
|
| 1082 |
+
///
|
| 1083 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1084 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1085 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1086 |
+
/// MaskedTileIteratorConcept
|
| 1087 |
+
///
|
| 1088 |
+
template <
|
| 1089 |
+
typename Shape_,
|
| 1090 |
+
typename Element_,
|
| 1091 |
+
int AdvanceRank,
|
| 1092 |
+
typename ThreadMap_,
|
| 1093 |
+
int AccessSize
|
| 1094 |
+
>
|
| 1095 |
+
class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2ColumnMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
|
| 1096 |
+
public:
|
| 1097 |
+
|
| 1098 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 1099 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1100 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1101 |
+
|
| 1102 |
+
using Shape = Shape_;
|
| 1103 |
+
using Element = Element_;
|
| 1104 |
+
using Layout = layout::AffineRank2ColumnMajor;
|
| 1105 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1106 |
+
using ThreadMap = ThreadMap_;
|
| 1107 |
+
|
| 1108 |
+
using Index = typename Layout::Index;
|
| 1109 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1110 |
+
|
| 1111 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1112 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1113 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1114 |
+
|
| 1115 |
+
using Pointer = Element *;
|
| 1116 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1117 |
+
|
| 1118 |
+
// Map to the underlying AffineRankN<2> layout
|
| 1119 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 1120 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 1121 |
+
Element,
|
| 1122 |
+
layout::AffineRankN<2>,
|
| 1123 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 1124 |
+
ThreadMap,
|
| 1125 |
+
AccessSize
|
| 1126 |
+
>;
|
| 1127 |
+
|
| 1128 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1129 |
+
|
| 1130 |
+
/// Fragment object to be loaded or stored
|
| 1131 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1132 |
+
|
| 1133 |
+
/// Predicate vector stores mask to guard accesses
|
| 1134 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1135 |
+
|
| 1136 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1137 |
+
class Params {
|
| 1138 |
+
private:
|
| 1139 |
+
|
| 1140 |
+
friend PredicatedTileIterator;
|
| 1141 |
+
|
| 1142 |
+
/// Parameters object
|
| 1143 |
+
typename UnderlyingIterator::Params params_;
|
| 1144 |
+
|
| 1145 |
+
public:
|
| 1146 |
+
|
| 1147 |
+
/// Default constructor
|
| 1148 |
+
Params() = default;
|
| 1149 |
+
|
| 1150 |
+
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1151 |
+
CUTLASS_HOST_DEVICE
|
| 1152 |
+
Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1)))
|
| 1153 |
+
{}
|
| 1154 |
+
};
|
| 1155 |
+
|
| 1156 |
+
private:
|
| 1157 |
+
|
| 1158 |
+
//
|
| 1159 |
+
// Data members
|
| 1160 |
+
//
|
| 1161 |
+
|
| 1162 |
+
/// Underlying AffineRankN<2> tile iterator
|
| 1163 |
+
UnderlyingIterator iterator_;
|
| 1164 |
+
|
| 1165 |
+
public:
|
| 1166 |
+
|
| 1167 |
+
/// Default constructor
|
| 1168 |
+
PredicatedTileIterator() = default;
|
| 1169 |
+
|
| 1170 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 1171 |
+
CUTLASS_HOST_DEVICE
|
| 1172 |
+
PredicatedTileIterator(
|
| 1173 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1174 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1175 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1176 |
+
int thread_id, ///< ID of each participating thread
|
| 1177 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 1178 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1179 |
+
):
|
| 1180 |
+
iterator_(
|
| 1181 |
+
params.params_,
|
| 1182 |
+
pointer,
|
| 1183 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 1184 |
+
thread_id,
|
| 1185 |
+
layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
|
| 1186 |
+
) { }
|
| 1187 |
+
|
| 1188 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 1189 |
+
CUTLASS_HOST_DEVICE
|
| 1190 |
+
PredicatedTileIterator(
|
| 1191 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1192 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1193 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1194 |
+
int thread_id ///< ID of each participating thread
|
| 1195 |
+
): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 1196 |
+
|
| 1197 |
+
/// Adds a pointer offset in units of Element
|
| 1198 |
+
CUTLASS_HOST_DEVICE
|
| 1199 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1200 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1201 |
+
}
|
| 1202 |
+
|
| 1203 |
+
/// Advances to the next tile in memory.
|
| 1204 |
+
///
|
| 1205 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 1206 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 1207 |
+
/// are lightweight and must only update the internal pointer.
|
| 1208 |
+
CUTLASS_HOST_DEVICE
|
| 1209 |
+
PredicatedTileIterator &operator++() {
|
| 1210 |
+
++iterator_;
|
| 1211 |
+
return *this;
|
| 1212 |
+
}
|
| 1213 |
+
|
| 1214 |
+
/// Advances to the next tile in memory.
|
| 1215 |
+
///
|
| 1216 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 1217 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 1218 |
+
/// are lightweight and must only update the internal pointer.
|
| 1219 |
+
CUTLASS_HOST_DEVICE
|
| 1220 |
+
PredicatedTileIterator operator++(int) {
|
| 1221 |
+
PredicatedTileIterator self(*this);
|
| 1222 |
+
operator++();
|
| 1223 |
+
return self;
|
| 1224 |
+
}
|
| 1225 |
+
|
| 1226 |
+
/// Clears the predicate set efficiently
|
| 1227 |
+
CUTLASS_HOST_DEVICE
|
| 1228 |
+
void clear_mask(bool enable = true) {
|
| 1229 |
+
iterator_.clear_mask(enable);
|
| 1230 |
+
}
|
| 1231 |
+
|
| 1232 |
+
/// Clears the predicate set efficiently
|
| 1233 |
+
CUTLASS_HOST_DEVICE
|
| 1234 |
+
void enable_mask() {
|
| 1235 |
+
iterator_.enable_mask();
|
| 1236 |
+
}
|
| 1237 |
+
|
| 1238 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1239 |
+
CUTLASS_HOST_DEVICE
|
| 1240 |
+
void set_mask(Mask const &mask) {
|
| 1241 |
+
iterator_.set_mask(mask);
|
| 1242 |
+
}
|
| 1243 |
+
|
| 1244 |
+
/// Gets the mask
|
| 1245 |
+
CUTLASS_HOST_DEVICE
|
| 1246 |
+
void get_mask(Mask &mask) {
|
| 1247 |
+
iterator_.get_mask(mask);
|
| 1248 |
+
}
|
| 1249 |
+
|
| 1250 |
+
/// Loads a fragment from memory
|
| 1251 |
+
CUTLASS_DEVICE
|
| 1252 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1253 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1254 |
+
}
|
| 1255 |
+
|
| 1256 |
+
/// Loads a fragment from memory
|
| 1257 |
+
CUTLASS_DEVICE
|
| 1258 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 1259 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 1260 |
+
}
|
| 1261 |
+
|
| 1262 |
+
/// Loads a fragment from memory
|
| 1263 |
+
CUTLASS_DEVICE
|
| 1264 |
+
void load(Fragment &frag) {
|
| 1265 |
+
load_with_pointer_offset(frag, 0);
|
| 1266 |
+
}
|
| 1267 |
+
|
| 1268 |
+
/// Store a fragment to memory
|
| 1269 |
+
CUTLASS_DEVICE
|
| 1270 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1271 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1272 |
+
}
|
| 1273 |
+
|
| 1274 |
+
/// Store a fragment to memory
|
| 1275 |
+
CUTLASS_DEVICE
|
| 1276 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 1277 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 1278 |
+
}
|
| 1279 |
+
|
| 1280 |
+
/// Store a fragment to memory
|
| 1281 |
+
CUTLASS_DEVICE
|
| 1282 |
+
void store(Fragment const &frag) {
|
| 1283 |
+
store_with_pointer_offset(frag, 0);
|
| 1284 |
+
}
|
| 1285 |
+
};
|
| 1286 |
+
|
| 1287 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1288 |
+
|
| 1289 |
+
/// Specialization of PredicatedTileIterator for affine rank 2 row-major data.
|
| 1290 |
+
///
|
| 1291 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1292 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1293 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1294 |
+
/// MaskedTileIteratorConcept
|
| 1295 |
+
///
|
| 1296 |
+
template <
|
| 1297 |
+
typename Shape_,
|
| 1298 |
+
typename Element_,
|
| 1299 |
+
int AdvanceRank,
|
| 1300 |
+
typename ThreadMap_,
|
| 1301 |
+
int AccessSize
|
| 1302 |
+
>
|
| 1303 |
+
class PredicatedTileIterator<Shape_, Element_, layout::AffineRank2RowMajor, AdvanceRank, ThreadMap_, AccessSize, false> {
|
| 1304 |
+
public:
|
| 1305 |
+
|
| 1306 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 1307 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1308 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1309 |
+
|
| 1310 |
+
using Shape = Shape_;
|
| 1311 |
+
using Element = Element_;
|
| 1312 |
+
using Layout = layout::AffineRank2RowMajor;
|
| 1313 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1314 |
+
using ThreadMap = ThreadMap_;
|
| 1315 |
+
|
| 1316 |
+
using Index = typename Layout::Index;
|
| 1317 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1318 |
+
|
| 1319 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1320 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1321 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1322 |
+
|
| 1323 |
+
using Pointer = Element *;
|
| 1324 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1325 |
+
|
| 1326 |
+
// Map to the underlying AffineRankN<2> layout
|
| 1327 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 1328 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 1329 |
+
Element,
|
| 1330 |
+
layout::AffineRankN<2>,
|
| 1331 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 1332 |
+
ThreadMap,
|
| 1333 |
+
AccessSize
|
| 1334 |
+
>;
|
| 1335 |
+
|
| 1336 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1337 |
+
|
| 1338 |
+
/// Fragment object to be loaded or stored
|
| 1339 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 1340 |
+
|
| 1341 |
+
/// Predicate vector stores mask to guard accesses
|
| 1342 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1343 |
+
|
| 1344 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1345 |
+
class Params {
|
| 1346 |
+
private:
|
| 1347 |
+
|
| 1348 |
+
friend PredicatedTileIterator;
|
| 1349 |
+
|
| 1350 |
+
/// Parameters object
|
| 1351 |
+
typename UnderlyingIterator::Params params_;
|
| 1352 |
+
|
| 1353 |
+
public:
|
| 1354 |
+
|
| 1355 |
+
/// Default constructor
|
| 1356 |
+
Params() = default;
|
| 1357 |
+
|
| 1358 |
+
/// Construct the Params object given an AffineRankN<2> tensor's layout
|
| 1359 |
+
CUTLASS_HOST_DEVICE
|
| 1360 |
+
Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
|
| 1361 |
+
};
|
| 1362 |
+
|
| 1363 |
+
|
| 1364 |
+
private:
|
| 1365 |
+
|
| 1366 |
+
//
|
| 1367 |
+
// Data members
|
| 1368 |
+
//
|
| 1369 |
+
|
| 1370 |
+
/// Underlying AffineRankN<2> tile iterator
|
| 1371 |
+
UnderlyingIterator iterator_;
|
| 1372 |
+
|
| 1373 |
+
public:
|
| 1374 |
+
|
| 1375 |
+
/// Default constructor
|
| 1376 |
+
PredicatedTileIterator() = default;
|
| 1377 |
+
|
| 1378 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 1379 |
+
CUTLASS_HOST_DEVICE
|
| 1380 |
+
PredicatedTileIterator(
|
| 1381 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1382 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1383 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1384 |
+
int thread_id, ///< ID of each participating thread
|
| 1385 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 1386 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1387 |
+
):
|
| 1388 |
+
iterator_(
|
| 1389 |
+
params.params_,
|
| 1390 |
+
pointer,
|
| 1391 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 1392 |
+
thread_id,
|
| 1393 |
+
layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
|
| 1394 |
+
) { }
|
| 1395 |
+
|
| 1396 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 1397 |
+
CUTLASS_HOST_DEVICE
|
| 1398 |
+
PredicatedTileIterator(
|
| 1399 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1400 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1401 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1402 |
+
int thread_id ///< ID of each participating thread
|
| 1403 |
+
): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 1404 |
+
|
| 1405 |
+
/// Adds a pointer offset in units of Element
|
| 1406 |
+
CUTLASS_HOST_DEVICE
|
| 1407 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1408 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1409 |
+
}
|
| 1410 |
+
|
| 1411 |
+
/// Advances to the next tile in memory.
|
| 1412 |
+
///
|
| 1413 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 1414 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 1415 |
+
/// are lightweight and must only update the internal pointer.
|
| 1416 |
+
CUTLASS_HOST_DEVICE
|
| 1417 |
+
PredicatedTileIterator &operator++() {
|
| 1418 |
+
++iterator_;
|
| 1419 |
+
return *this;
|
| 1420 |
+
}
|
| 1421 |
+
|
| 1422 |
+
/// Advances to the next tile in memory.
|
| 1423 |
+
///
|
| 1424 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 1425 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 1426 |
+
/// are lightweight and must only update the internal pointer.
|
| 1427 |
+
CUTLASS_HOST_DEVICE
|
| 1428 |
+
PredicatedTileIterator operator++(int) {
|
| 1429 |
+
PredicatedTileIterator self(*this);
|
| 1430 |
+
operator++();
|
| 1431 |
+
return self;
|
| 1432 |
+
}
|
| 1433 |
+
|
| 1434 |
+
/// Clears the predicate set efficiently
|
| 1435 |
+
CUTLASS_HOST_DEVICE
|
| 1436 |
+
void clear_mask(bool enable = true) {
|
| 1437 |
+
iterator_.clear_mask(enable);
|
| 1438 |
+
}
|
| 1439 |
+
|
| 1440 |
+
/// Clears the predicate set efficiently
|
| 1441 |
+
CUTLASS_HOST_DEVICE
|
| 1442 |
+
void enable_mask() {
|
| 1443 |
+
iterator_.enable_mask();
|
| 1444 |
+
}
|
| 1445 |
+
|
| 1446 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1447 |
+
CUTLASS_HOST_DEVICE
|
| 1448 |
+
void set_mask(Mask const &mask) {
|
| 1449 |
+
iterator_.set_mask(mask);
|
| 1450 |
+
}
|
| 1451 |
+
|
| 1452 |
+
/// Gets the mask
|
| 1453 |
+
CUTLASS_HOST_DEVICE
|
| 1454 |
+
void get_mask(Mask &mask) {
|
| 1455 |
+
iterator_.get_mask(mask);
|
| 1456 |
+
}
|
| 1457 |
+
|
| 1458 |
+
/// Loads a fragment from memory
|
| 1459 |
+
CUTLASS_DEVICE
|
| 1460 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1461 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1462 |
+
}
|
| 1463 |
+
|
| 1464 |
+
/// Loads a fragment from memory
|
| 1465 |
+
CUTLASS_DEVICE
|
| 1466 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 1467 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 1468 |
+
}
|
| 1469 |
+
|
| 1470 |
+
/// Loads a fragment from memory
|
| 1471 |
+
CUTLASS_DEVICE
|
| 1472 |
+
void load(Fragment &frag) {
|
| 1473 |
+
load_with_pointer_offset(frag, 0);
|
| 1474 |
+
}
|
| 1475 |
+
|
| 1476 |
+
/// Store a fragment to memory
|
| 1477 |
+
CUTLASS_DEVICE
|
| 1478 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1479 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1480 |
+
}
|
| 1481 |
+
|
| 1482 |
+
/// Store a fragment to memory
|
| 1483 |
+
CUTLASS_DEVICE
|
| 1484 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 1485 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 1486 |
+
}
|
| 1487 |
+
|
| 1488 |
+
/// Store a fragment to memory
|
| 1489 |
+
CUTLASS_DEVICE
|
| 1490 |
+
void store(Fragment const &frag) {
|
| 1491 |
+
store_with_pointer_offset(frag, 0);
|
| 1492 |
+
}
|
| 1493 |
+
};
|
| 1494 |
+
|
| 1495 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1496 |
+
|
| 1497 |
+
/// Specialization of PredicatedTileIterator for interleaved data. It is mapped
|
| 1498 |
+
/// to the congruous layout.
|
| 1499 |
+
///
|
| 1500 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1501 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1502 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1503 |
+
/// MaskedTileIteratorConcept
|
| 1504 |
+
///
|
| 1505 |
+
|
| 1506 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1507 |
+
typename ThreadMap_, int AccessSize, int InterleavedK>
|
| 1508 |
+
class PredicatedTileIterator<Shape_, Element_,
|
| 1509 |
+
layout::ColumnMajorInterleaved<InterleavedK>,
|
| 1510 |
+
AdvanceRank, ThreadMap_, AccessSize, false> {
|
| 1511 |
+
public:
|
| 1512 |
+
static_assert(
|
| 1513 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1514 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1515 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1516 |
+
|
| 1517 |
+
using Shape = Shape_;
|
| 1518 |
+
using Element = Element_;
|
| 1519 |
+
static int const kInterleavedK = InterleavedK;
|
| 1520 |
+
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
|
| 1521 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1522 |
+
using ThreadMap = ThreadMap_;
|
| 1523 |
+
|
| 1524 |
+
using Index = typename Layout::Index;
|
| 1525 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1526 |
+
|
| 1527 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1528 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1529 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1530 |
+
|
| 1531 |
+
using Pointer = Element *;
|
| 1532 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1533 |
+
|
| 1534 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 1535 |
+
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
|
| 1536 |
+
Shape::kColumn / kInterleavedK>,
|
| 1537 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1541 |
+
|
| 1542 |
+
/// Fragment object to be loaded or stored
|
| 1543 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 1544 |
+
ThreadMap::kElementsPerAccess>;
|
| 1545 |
+
|
| 1546 |
+
/// Predicate vector stores mask to guard accesses
|
| 1547 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1548 |
+
|
| 1549 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1550 |
+
class Params {
|
| 1551 |
+
private:
|
| 1552 |
+
friend PredicatedTileIterator;
|
| 1553 |
+
|
| 1554 |
+
/// Parameters object
|
| 1555 |
+
typename UnderlyingIterator::Params params_;
|
| 1556 |
+
|
| 1557 |
+
public:
|
| 1558 |
+
|
| 1559 |
+
/// Default constructor
|
| 1560 |
+
Params() = default;
|
| 1561 |
+
|
| 1562 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1563 |
+
CUTLASS_HOST_DEVICE
|
| 1564 |
+
Params(Layout const &layout)
|
| 1565 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1566 |
+
|
| 1567 |
+
CUTLASS_HOST_DEVICE
|
| 1568 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 1569 |
+
: params_(base) {}
|
| 1570 |
+
|
| 1571 |
+
};
|
| 1572 |
+
|
| 1573 |
+
private:
|
| 1574 |
+
//
|
| 1575 |
+
// Data members
|
| 1576 |
+
//
|
| 1577 |
+
|
| 1578 |
+
/// Underlying pitch-linear tile iterator
|
| 1579 |
+
UnderlyingIterator iterator_;
|
| 1580 |
+
|
| 1581 |
+
public:
|
| 1582 |
+
|
| 1583 |
+
/// Default constructor
|
| 1584 |
+
PredicatedTileIterator() = default;
|
| 1585 |
+
|
| 1586 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1587 |
+
/// and thread ID
|
| 1588 |
+
CUTLASS_HOST_DEVICE
|
| 1589 |
+
PredicatedTileIterator(
|
| 1590 |
+
/// Precomputed parameters object
|
| 1591 |
+
Params const ¶ms,
|
| 1592 |
+
/// Pointer to start of tensor
|
| 1593 |
+
Pointer pointer,
|
| 1594 |
+
/// Extent of tensor
|
| 1595 |
+
TensorCoord extent,
|
| 1596 |
+
/// ID of each participating thread
|
| 1597 |
+
int thread_id,
|
| 1598 |
+
/// Initial offset of threadblock
|
| 1599 |
+
TensorCoord const &threadblock_offset,
|
| 1600 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1601 |
+
)
|
| 1602 |
+
: iterator_(params.params_, pointer,
|
| 1603 |
+
layout::PitchLinearCoord(extent.row() * kInterleavedK,
|
| 1604 |
+
extent.column() / kInterleavedK),
|
| 1605 |
+
thread_id,
|
| 1606 |
+
layout::PitchLinearCoord(
|
| 1607 |
+
threadblock_offset.row() * kInterleavedK,
|
| 1608 |
+
threadblock_offset.column() / kInterleavedK)) {}
|
| 1609 |
+
|
| 1610 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 1611 |
+
CUTLASS_HOST_DEVICE
|
| 1612 |
+
PredicatedTileIterator(
|
| 1613 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1614 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1615 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1616 |
+
int thread_id ///< ID of each participating thread
|
| 1617 |
+
)
|
| 1618 |
+
: PredicatedTileIterator(params, pointer, extent, thread_id,
|
| 1619 |
+
make_Coord(0, 0)) {}
|
| 1620 |
+
|
| 1621 |
+
/// Adds a pointer offset in units of Element
|
| 1622 |
+
CUTLASS_HOST_DEVICE
|
| 1623 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1624 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1625 |
+
}
|
| 1626 |
+
|
| 1627 |
+
/// Advances to the next tile in memory.
|
| 1628 |
+
///
|
| 1629 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1630 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1631 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1632 |
+
/// pointer.
|
| 1633 |
+
CUTLASS_HOST_DEVICE
|
| 1634 |
+
PredicatedTileIterator &operator++() {
|
| 1635 |
+
++iterator_;
|
| 1636 |
+
return *this;
|
| 1637 |
+
}
|
| 1638 |
+
|
| 1639 |
+
/// Advances to the next tile in memory.
|
| 1640 |
+
///
|
| 1641 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1642 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1643 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1644 |
+
/// pointer.
|
| 1645 |
+
CUTLASS_HOST_DEVICE
|
| 1646 |
+
PredicatedTileIterator operator++(int) {
|
| 1647 |
+
PredicatedTileIterator self(*this);
|
| 1648 |
+
operator++();
|
| 1649 |
+
return self;
|
| 1650 |
+
}
|
| 1651 |
+
|
| 1652 |
+
/// Clears the predicate set efficiently
|
| 1653 |
+
CUTLASS_HOST_DEVICE
|
| 1654 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1655 |
+
|
| 1656 |
+
/// Clears the predicate set efficiently
|
| 1657 |
+
CUTLASS_HOST_DEVICE
|
| 1658 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1659 |
+
|
| 1660 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1661 |
+
CUTLASS_HOST_DEVICE
|
| 1662 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1663 |
+
|
| 1664 |
+
/// Gets the mask
|
| 1665 |
+
CUTLASS_HOST_DEVICE
|
| 1666 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1667 |
+
|
| 1668 |
+
/// Loads a fragment from memory
|
| 1669 |
+
CUTLASS_DEVICE
|
| 1670 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1671 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1672 |
+
}
|
| 1673 |
+
|
| 1674 |
+
/// Loads a fragment from memory
|
| 1675 |
+
CUTLASS_DEVICE
|
| 1676 |
+
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
| 1677 |
+
|
| 1678 |
+
/// Store a fragment to memory
|
| 1679 |
+
CUTLASS_DEVICE
|
| 1680 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1681 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1682 |
+
}
|
| 1683 |
+
|
| 1684 |
+
/// Store a fragment to memory
|
| 1685 |
+
CUTLASS_DEVICE
|
| 1686 |
+
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
| 1687 |
+
};
|
| 1688 |
+
|
| 1689 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1690 |
+
|
| 1691 |
+
/// Specialization of PredicatedTileIterator for interleaved-32 data. It is
|
| 1692 |
+
/// mapped to the congruous layout.
|
| 1693 |
+
///
|
| 1694 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1695 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1696 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 1697 |
+
/// MaskedTileIteratorConcept
|
| 1698 |
+
///
|
| 1699 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1700 |
+
typename ThreadMap_, int AccessSize, int InterleavedK>
|
| 1701 |
+
class PredicatedTileIterator<Shape_, Element_,
|
| 1702 |
+
layout::RowMajorInterleaved<InterleavedK>,
|
| 1703 |
+
AdvanceRank, ThreadMap_, AccessSize, false> {
|
| 1704 |
+
public:
|
| 1705 |
+
static_assert(
|
| 1706 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1707 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1708 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1709 |
+
|
| 1710 |
+
using Shape = Shape_;
|
| 1711 |
+
using Element = Element_;
|
| 1712 |
+
static int const kInterleavedK = InterleavedK;
|
| 1713 |
+
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
|
| 1714 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1715 |
+
using ThreadMap = ThreadMap_;
|
| 1716 |
+
|
| 1717 |
+
using Index = typename Layout::Index;
|
| 1718 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1719 |
+
|
| 1720 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1721 |
+
using TensorView = TensorView<Element, Layout>;
|
| 1722 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1723 |
+
|
| 1724 |
+
using Pointer = Element *;
|
| 1725 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 1726 |
+
|
| 1727 |
+
using UnderlyingIterator = PredicatedTileIterator<
|
| 1728 |
+
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
|
| 1729 |
+
Shape::kRow / kInterleavedK>,
|
| 1730 |
+
Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
|
| 1731 |
+
|
| 1732 |
+
|
| 1733 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1734 |
+
|
| 1735 |
+
/// Fragment object to be loaded or stored
|
| 1736 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 1737 |
+
ThreadMap::kElementsPerAccess>;
|
| 1738 |
+
|
| 1739 |
+
/// Predicate vector stores mask to guard accesses
|
| 1740 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 1741 |
+
|
| 1742 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 1743 |
+
class Params {
|
| 1744 |
+
private:
|
| 1745 |
+
friend PredicatedTileIterator;
|
| 1746 |
+
|
| 1747 |
+
/// Parameters object
|
| 1748 |
+
typename UnderlyingIterator::Params params_;
|
| 1749 |
+
|
| 1750 |
+
public:
|
| 1751 |
+
|
| 1752 |
+
/// Default constructor
|
| 1753 |
+
Params() = default;
|
| 1754 |
+
|
| 1755 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 1756 |
+
CUTLASS_HOST_DEVICE
|
| 1757 |
+
Params(Layout const &layout)
|
| 1758 |
+
: params_(layout::PitchLinear(layout.stride(0))) {}
|
| 1759 |
+
|
| 1760 |
+
CUTLASS_HOST_DEVICE
|
| 1761 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 1762 |
+
: params_(base) {}
|
| 1763 |
+
};
|
| 1764 |
+
|
| 1765 |
+
private:
|
| 1766 |
+
//
|
| 1767 |
+
// Data members
|
| 1768 |
+
//
|
| 1769 |
+
|
| 1770 |
+
/// Underlying pitch-linear tile iterator
|
| 1771 |
+
UnderlyingIterator iterator_;
|
| 1772 |
+
|
| 1773 |
+
public:
|
| 1774 |
+
|
| 1775 |
+
/// Default constructor
|
| 1776 |
+
PredicatedTileIterator() = default;
|
| 1777 |
+
|
| 1778 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 1779 |
+
/// and thread ID
|
| 1780 |
+
CUTLASS_HOST_DEVICE
|
| 1781 |
+
PredicatedTileIterator(
|
| 1782 |
+
/// Precomputed parameters object
|
| 1783 |
+
Params const ¶ms,
|
| 1784 |
+
/// Pointer to start of tensor
|
| 1785 |
+
Pointer pointer,
|
| 1786 |
+
/// Extent of tensor
|
| 1787 |
+
TensorCoord extent,
|
| 1788 |
+
/// ID of each participating thread
|
| 1789 |
+
int thread_id,
|
| 1790 |
+
/// Initial offset of threadblock
|
| 1791 |
+
TensorCoord const &threadblock_offset,
|
| 1792 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 1793 |
+
)
|
| 1794 |
+
: iterator_(params.params_, pointer,
|
| 1795 |
+
layout::PitchLinearCoord(extent.column() * kInterleavedK,
|
| 1796 |
+
extent.row() / kInterleavedK),
|
| 1797 |
+
thread_id,
|
| 1798 |
+
layout::PitchLinearCoord(
|
| 1799 |
+
threadblock_offset.column() * kInterleavedK,
|
| 1800 |
+
threadblock_offset.row() / kInterleavedK)) {}
|
| 1801 |
+
|
| 1802 |
+
/// Construct a PredicatedTileIterator with zero threadblock offset
|
| 1803 |
+
CUTLASS_HOST_DEVICE
|
| 1804 |
+
PredicatedTileIterator(
|
| 1805 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 1806 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 1807 |
+
TensorCoord extent, ///< Extent of tensor
|
| 1808 |
+
int thread_id ///< ID of each participating thread
|
| 1809 |
+
)
|
| 1810 |
+
: PredicatedTileIterator(params, pointer, extent, thread_id,
|
| 1811 |
+
make_Coord(0, 0)) {}
|
| 1812 |
+
|
| 1813 |
+
/// Adds a pointer offset in units of Element
|
| 1814 |
+
CUTLASS_HOST_DEVICE
|
| 1815 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1816 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1817 |
+
}
|
| 1818 |
+
|
| 1819 |
+
/// Advances to the next tile in memory.
|
| 1820 |
+
///
|
| 1821 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1822 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1823 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1824 |
+
/// pointer.
|
| 1825 |
+
CUTLASS_HOST_DEVICE
|
| 1826 |
+
PredicatedTileIterator &operator++() {
|
| 1827 |
+
++iterator_;
|
| 1828 |
+
return *this;
|
| 1829 |
+
}
|
| 1830 |
+
|
| 1831 |
+
/// Advances to the next tile in memory.
|
| 1832 |
+
///
|
| 1833 |
+
/// The first time this method is called, predicates are updated, and the
|
| 1834 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 1835 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 1836 |
+
/// pointer.
|
| 1837 |
+
CUTLASS_HOST_DEVICE
|
| 1838 |
+
PredicatedTileIterator operator++(int) {
|
| 1839 |
+
PredicatedTileIterator self(*this);
|
| 1840 |
+
operator++();
|
| 1841 |
+
return self;
|
| 1842 |
+
}
|
| 1843 |
+
|
| 1844 |
+
/// Clears the predicate set efficiently
|
| 1845 |
+
CUTLASS_HOST_DEVICE
|
| 1846 |
+
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
|
| 1847 |
+
|
| 1848 |
+
/// Clears the predicate set efficiently
|
| 1849 |
+
CUTLASS_HOST_DEVICE
|
| 1850 |
+
void enable_mask() { iterator_.enable_mask(); }
|
| 1851 |
+
|
| 1852 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 1853 |
+
CUTLASS_HOST_DEVICE
|
| 1854 |
+
void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
|
| 1855 |
+
|
| 1856 |
+
/// Gets the mask
|
| 1857 |
+
CUTLASS_HOST_DEVICE
|
| 1858 |
+
void get_mask(Mask &mask) { iterator_.get_mask(mask); }
|
| 1859 |
+
|
| 1860 |
+
/// Loads a fragment from memory
|
| 1861 |
+
CUTLASS_DEVICE
|
| 1862 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 1863 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 1864 |
+
}
|
| 1865 |
+
|
| 1866 |
+
/// Loads a fragment from memory
|
| 1867 |
+
CUTLASS_DEVICE
|
| 1868 |
+
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
| 1869 |
+
|
| 1870 |
+
/// Store a fragment to memory
|
| 1871 |
+
CUTLASS_DEVICE
|
| 1872 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 1873 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 1874 |
+
}
|
| 1875 |
+
|
| 1876 |
+
/// Store a fragment to memory
|
| 1877 |
+
CUTLASS_DEVICE
|
| 1878 |
+
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
| 1879 |
+
};
|
| 1880 |
+
|
| 1881 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1882 |
+
|
| 1883 |
+
} // namespace threadblock
|
| 1884 |
+
} // namespace transform
|
| 1885 |
+
} // namespace cutlass
|
| 1886 |
+
|
| 1887 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h
ADDED
|
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
|
| 33 |
+
|
| 34 |
+
This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile
|
| 35 |
+
first, with the objective of minimizing predicate mask updates during steady-state operation.
|
| 36 |
+
|
| 37 |
+
A precomputed "Params" object minimizes the amount of state that must be stored in registers,
|
| 38 |
+
and integer addition is used to advance the pointer through memory.
|
| 39 |
+
*/
|
| 40 |
+
|
| 41 |
+
#pragma once
|
| 42 |
+
|
| 43 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h"
|
| 44 |
+
#include "cutlass/transform/thread/transpose.h"
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace transform {
|
| 50 |
+
namespace threadblock {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// PredicatedTileIterator2dThreadTile
|
| 55 |
+
///
|
| 56 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 57 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 58 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 59 |
+
/// MaskedTileIteratorConcept
|
| 60 |
+
///
|
| 61 |
+
/// Regular tile iterator using a precomputed control structure to minimize register liveness
|
| 62 |
+
/// and integer arithmetic.
|
| 63 |
+
///
|
| 64 |
+
/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
|
| 65 |
+
///
|
| 66 |
+
/// Base pointer and tensor extents may be specified at the time the iterator is constructed.
|
| 67 |
+
/// Subsequently, they are assumed to be immutable.
|
| 68 |
+
///
|
| 69 |
+
/// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
|
| 70 |
+
/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
|
| 71 |
+
///
|
| 72 |
+
/// Vistitation order is intended to first visit a "residual" tile that may be partially full in
|
| 73 |
+
/// both the advance dimension and the steady-state dimension. This is assumed to be the last
|
| 74 |
+
/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
|
| 75 |
+
/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
|
| 76 |
+
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
| 77 |
+
/// live register state and pointer arithmetic instructions.
|
| 78 |
+
///
|
| 79 |
+
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
| 80 |
+
/// outside any looping structure to minimize integer arithmetic.
|
| 81 |
+
///
|
| 82 |
+
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
| 83 |
+
/// the iterator.
|
| 84 |
+
///
|
| 85 |
+
///
|
| 86 |
+
/// Example:
|
| 87 |
+
///
|
| 88 |
+
/// An efficient pipeline structure may be constructed as follows:
|
| 89 |
+
///
|
| 90 |
+
// template <typename Iterator>
|
| 91 |
+
// __global__ void kernel(
|
| 92 |
+
// typename Iterator::Params params,
|
| 93 |
+
// typename Iterator::Element *ptr,
|
| 94 |
+
// TensorCoord extent) {
|
| 95 |
+
//
|
| 96 |
+
// typename Iterator::Fragment fragment;
|
| 97 |
+
//
|
| 98 |
+
// TensorCoord threadblock_offset(0, 0);
|
| 99 |
+
//
|
| 100 |
+
// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
|
| 101 |
+
//
|
| 102 |
+
//
|
| 103 |
+
// fragment = *iter; // load "residue" tile first
|
| 104 |
+
// ++iter; // advance to first "steady state" tile and update internal masks
|
| 105 |
+
//
|
| 106 |
+
//
|
| 107 |
+
// #pragma unroll
|
| 108 |
+
// for (int i = Remaining - 1; i >= 0; --i) {
|
| 109 |
+
//
|
| 110 |
+
// f(fragment);
|
| 111 |
+
//
|
| 112 |
+
// if (!i) {
|
| 113 |
+
// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
|
| 114 |
+
// }
|
| 115 |
+
//
|
| 116 |
+
// fragment = *iter; // load tile during "steady state" phase
|
| 117 |
+
// ++iter; // advance to next tile - lightweight due to steady-state masks
|
| 118 |
+
// }
|
| 119 |
+
// }
|
| 120 |
+
//
|
| 121 |
+
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
|
| 122 |
+
//
|
| 123 |
+
// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile;
|
| 124 |
+
//
|
| 125 |
+
// typename Iterator::Params params(view.layout());
|
| 126 |
+
//
|
| 127 |
+
// kernel<Iterator>(params, view.data());
|
| 128 |
+
// }
|
| 129 |
+
///
|
| 130 |
+
///
|
| 131 |
+
template <
|
| 132 |
+
typename Shape,
|
| 133 |
+
typename Element,
|
| 134 |
+
typename Layout,
|
| 135 |
+
int AdvanceRank,
|
| 136 |
+
typename ThreadMap,
|
| 137 |
+
bool Transpose = false
|
| 138 |
+
>
|
| 139 |
+
class PredicatedTileIterator2dThreadTile;
|
| 140 |
+
|
| 141 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 142 |
+
|
| 143 |
+
/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
|
| 144 |
+
///
|
| 145 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 146 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 147 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 148 |
+
/// MaskedTileIteratorConcept
|
| 149 |
+
///
|
| 150 |
+
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, bool Transpose_>
|
| 151 |
+
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_, Transpose_> {
|
| 152 |
+
public:
|
| 153 |
+
static_assert(
|
| 154 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 155 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 156 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 157 |
+
|
| 158 |
+
using Shape = Shape_;
|
| 159 |
+
using Element = Element_;
|
| 160 |
+
using Layout = layout::PitchLinear;
|
| 161 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 162 |
+
using ThreadMap = ThreadMap_;
|
| 163 |
+
|
| 164 |
+
using Index = typename Layout::Index;
|
| 165 |
+
using LongIndex = typename Layout::LongIndex;
|
| 166 |
+
|
| 167 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 168 |
+
using TensorView = TensorView<Element, Layout>;
|
| 169 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 170 |
+
|
| 171 |
+
using Pointer = Element *;
|
| 172 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 173 |
+
|
| 174 |
+
/// Type used for internal memory accesses
|
| 175 |
+
/// extra set of parenthesis is needed for VS compiler
|
| 176 |
+
struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value /
|
| 177 |
+
8)) AccessType {
|
| 178 |
+
|
| 179 |
+
Array<Element, ThreadMap::kElementsPerAccess> storage;
|
| 180 |
+
|
| 181 |
+
static int const kElements = ThreadMap::kElementsPerAccess;
|
| 182 |
+
};
|
| 183 |
+
|
| 184 |
+
/// Optionally this fragment can be 4x4 transposed
|
| 185 |
+
using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>;
|
| 186 |
+
static bool const transpose = Transpose_;
|
| 187 |
+
|
| 188 |
+
/// Underlying iterator to compute the addresses
|
| 189 |
+
using TileAccessIterator =
|
| 190 |
+
PredicatedTileAccessIterator2dThreadTile<Shape, Element, Layout, kAdvanceRank,
|
| 191 |
+
ThreadMap, AccessType>;
|
| 192 |
+
|
| 193 |
+
/// Fragment object to be loaded or stored
|
| 194 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 195 |
+
ThreadMap::ThreadAccessShape::kCount>;
|
| 196 |
+
|
| 197 |
+
/// Predicate vector stores mask to guard accesses
|
| 198 |
+
using Mask = typename TileAccessIterator::Mask;
|
| 199 |
+
|
| 200 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 201 |
+
class Params {
|
| 202 |
+
public:
|
| 203 |
+
using Base = typename TileAccessIterator::Params::Base;
|
| 204 |
+
|
| 205 |
+
friend PredicatedTileIterator2dThreadTile;
|
| 206 |
+
|
| 207 |
+
private:
|
| 208 |
+
/// Parameters object
|
| 209 |
+
typename TileAccessIterator::Params params_;
|
| 210 |
+
|
| 211 |
+
public:
|
| 212 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 213 |
+
CUTLASS_HOST_DEVICE
|
| 214 |
+
Params(Layout const &layout) : params_(layout) { }
|
| 215 |
+
|
| 216 |
+
CUTLASS_HOST_DEVICE
|
| 217 |
+
Params() { }
|
| 218 |
+
|
| 219 |
+
CUTLASS_HOST_DEVICE
|
| 220 |
+
Params(Base const &base)
|
| 221 |
+
: params_(base) {}
|
| 222 |
+
};
|
| 223 |
+
|
| 224 |
+
private:
|
| 225 |
+
/// Internal pointer type permits fast address arithmetic
|
| 226 |
+
using BytePointer = char *;
|
| 227 |
+
|
| 228 |
+
private:
|
| 229 |
+
//
|
| 230 |
+
// Data members
|
| 231 |
+
//
|
| 232 |
+
|
| 233 |
+
/// Data member to the tile access iterator
|
| 234 |
+
TileAccessIterator address_iterator_;
|
| 235 |
+
|
| 236 |
+
public:
|
| 237 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 238 |
+
/// and thread ID
|
| 239 |
+
CUTLASS_HOST_DEVICE
|
| 240 |
+
PredicatedTileIterator2dThreadTile(
|
| 241 |
+
/// Precomputed parameters object
|
| 242 |
+
Params const ¶ms,
|
| 243 |
+
/// Pointer to start of tensor
|
| 244 |
+
Pointer pointer,
|
| 245 |
+
/// Extent of tensor
|
| 246 |
+
TensorCoord extent,
|
| 247 |
+
/// ID of each participating thread
|
| 248 |
+
int thread_id,
|
| 249 |
+
/// Initial offset of threadblock
|
| 250 |
+
TensorCoord const &threadblock_offset,
|
| 251 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 252 |
+
)
|
| 253 |
+
: address_iterator_(params.params_, pointer, extent, thread_id,
|
| 254 |
+
threadblock_offset) {}
|
| 255 |
+
|
| 256 |
+
/// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
|
| 257 |
+
CUTLASS_HOST_DEVICE
|
| 258 |
+
PredicatedTileIterator2dThreadTile(
|
| 259 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 260 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 261 |
+
TensorCoord extent, ///< Extent of tensor
|
| 262 |
+
int thread_id ///< ID of each participating thread
|
| 263 |
+
)
|
| 264 |
+
: PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id,
|
| 265 |
+
make_Coord(0, 0)) {}
|
| 266 |
+
|
| 267 |
+
/// Adds a pointer offset in units of Element
|
| 268 |
+
CUTLASS_HOST_DEVICE
|
| 269 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 270 |
+
address_iterator_.add_pointer_offset(pointer_offset);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
/// Advances to the next tile in memory.
|
| 274 |
+
///
|
| 275 |
+
/// The first time this method is called, predicates are updated, and the
|
| 276 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 277 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 278 |
+
/// pointer.
|
| 279 |
+
CUTLASS_HOST_DEVICE
|
| 280 |
+
PredicatedTileIterator2dThreadTile &operator++() {
|
| 281 |
+
if (kAdvanceRank)
|
| 282 |
+
address_iterator_.add_tile_offset({0, 1});
|
| 283 |
+
else
|
| 284 |
+
address_iterator_.add_tile_offset({1, 0});
|
| 285 |
+
|
| 286 |
+
return *this;
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
/// Advances to the next tile in memory.
|
| 290 |
+
///
|
| 291 |
+
/// The first time this method is called, predicates are updated, and the
|
| 292 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 293 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 294 |
+
/// pointer.
|
| 295 |
+
CUTLASS_HOST_DEVICE
|
| 296 |
+
PredicatedTileIterator2dThreadTile operator++(int) {
|
| 297 |
+
PredicatedTileIterator2dThreadTile self(*this);
|
| 298 |
+
operator++();
|
| 299 |
+
return self;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
/// Clears the predicate set efficiently
|
| 303 |
+
CUTLASS_HOST_DEVICE
|
| 304 |
+
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
| 305 |
+
|
| 306 |
+
/// Clears the predicate set efficiently
|
| 307 |
+
CUTLASS_HOST_DEVICE
|
| 308 |
+
void enable_mask() { address_iterator_.enable_mask(); }
|
| 309 |
+
|
| 310 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
|
| 313 |
+
|
| 314 |
+
/// Gets the mask
|
| 315 |
+
CUTLASS_HOST_DEVICE
|
| 316 |
+
void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
|
| 317 |
+
|
| 318 |
+
/// Loads a fragment from memory
|
| 319 |
+
CUTLASS_DEVICE
|
| 320 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 321 |
+
|
| 322 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 323 |
+
|
| 324 |
+
CUTLASS_PRAGMA_UNROLL
|
| 325 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 326 |
+
CUTLASS_PRAGMA_UNROLL
|
| 327 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 328 |
+
CUTLASS_PRAGMA_UNROLL
|
| 329 |
+
for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
|
| 330 |
+
|
| 331 |
+
int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
|
| 332 |
+
s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
|
| 333 |
+
|
| 334 |
+
address_iterator_.set_iteration_index(access_idx);
|
| 335 |
+
if (address_iterator_.valid()) {
|
| 336 |
+
|
| 337 |
+
frag_ptr[access_idx] =
|
| 338 |
+
*(address_iterator_.get() + pointer_offset);
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
++address_iterator_;
|
| 342 |
+
}
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
if (transpose) {
|
| 347 |
+
Transform t;
|
| 348 |
+
t.transform(frag, frag);
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
/// Loads a fragment from memory
|
| 353 |
+
CUTLASS_DEVICE
|
| 354 |
+
void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
|
| 355 |
+
|
| 356 |
+
/// Store a fragment to memory
|
| 357 |
+
CUTLASS_DEVICE
|
| 358 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 359 |
+
|
| 360 |
+
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
| 361 |
+
|
| 362 |
+
CUTLASS_PRAGMA_UNROLL
|
| 363 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 364 |
+
CUTLASS_PRAGMA_UNROLL
|
| 365 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 366 |
+
CUTLASS_PRAGMA_UNROLL
|
| 367 |
+
for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){
|
| 368 |
+
|
| 369 |
+
int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \
|
| 370 |
+
s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;
|
| 371 |
+
|
| 372 |
+
address_iterator_.set_iteration_index(access_idx);
|
| 373 |
+
if (address_iterator_.valid()) {
|
| 374 |
+
*(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
|
| 375 |
+
}
|
| 376 |
+
++address_iterator_;
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
/// Store a fragment to memory
|
| 383 |
+
CUTLASS_DEVICE
|
| 384 |
+
void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
|
| 385 |
+
};
|
| 386 |
+
|
| 387 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 388 |
+
|
| 389 |
+
/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
|
| 390 |
+
///
|
| 391 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 392 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 393 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 394 |
+
/// MaskedTileIteratorConcept
|
| 395 |
+
///
|
| 396 |
+
template <
|
| 397 |
+
typename Shape_,
|
| 398 |
+
typename Element_,
|
| 399 |
+
int AdvanceRank,
|
| 400 |
+
typename ThreadMap_,
|
| 401 |
+
bool Transpose_
|
| 402 |
+
>
|
| 403 |
+
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, Transpose_> {
|
| 404 |
+
public:
|
| 405 |
+
|
| 406 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 407 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 408 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 409 |
+
|
| 410 |
+
using Shape = Shape_;
|
| 411 |
+
using Element = Element_;
|
| 412 |
+
using Layout = layout::ColumnMajor;
|
| 413 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 414 |
+
using ThreadMap = ThreadMap_;
|
| 415 |
+
static bool const Transpose = Transpose_;
|
| 416 |
+
|
| 417 |
+
using Index = typename Layout::Index;
|
| 418 |
+
using LongIndex = typename Layout::LongIndex;
|
| 419 |
+
|
| 420 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 421 |
+
using TensorView = TensorView<Element, Layout>;
|
| 422 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 423 |
+
|
| 424 |
+
using Pointer = Element *;
|
| 425 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 426 |
+
|
| 427 |
+
using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
|
| 428 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 429 |
+
Element,
|
| 430 |
+
layout::PitchLinear,
|
| 431 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 432 |
+
ThreadMap,
|
| 433 |
+
Transpose
|
| 434 |
+
>;
|
| 435 |
+
|
| 436 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 437 |
+
|
| 438 |
+
/// Fragment object to be loaded or stored
|
| 439 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
|
| 440 |
+
|
| 441 |
+
/// Predicate vector stores mask to guard accesses
|
| 442 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 443 |
+
|
| 444 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 445 |
+
class Params {
|
| 446 |
+
private:
|
| 447 |
+
|
| 448 |
+
friend PredicatedTileIterator2dThreadTile;
|
| 449 |
+
|
| 450 |
+
/// Parameters object
|
| 451 |
+
typename UnderlyingIterator::Params params_;
|
| 452 |
+
|
| 453 |
+
public:
|
| 454 |
+
|
| 455 |
+
CUTLASS_HOST_DEVICE
|
| 456 |
+
Params() { }
|
| 457 |
+
|
| 458 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 459 |
+
CUTLASS_HOST_DEVICE
|
| 460 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}
|
| 461 |
+
|
| 462 |
+
CUTLASS_HOST_DEVICE
|
| 463 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 464 |
+
: params_(base) {}
|
| 465 |
+
};
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
private:
|
| 469 |
+
|
| 470 |
+
//
|
| 471 |
+
// Data members
|
| 472 |
+
//
|
| 473 |
+
|
| 474 |
+
/// Underlying pitch-linear tile iterator
|
| 475 |
+
UnderlyingIterator iterator_;
|
| 476 |
+
|
| 477 |
+
public:
|
| 478 |
+
|
| 479 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 480 |
+
CUTLASS_HOST_DEVICE
|
| 481 |
+
PredicatedTileIterator2dThreadTile(
|
| 482 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 483 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 484 |
+
TensorCoord extent, ///< Extent of tensor
|
| 485 |
+
int thread_id, ///< ID of each participating thread
|
| 486 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 487 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 488 |
+
):
|
| 489 |
+
iterator_(
|
| 490 |
+
params.params_,
|
| 491 |
+
pointer,
|
| 492 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 493 |
+
thread_id,
|
| 494 |
+
layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
|
| 495 |
+
) { }
|
| 496 |
+
|
| 497 |
+
/// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
|
| 498 |
+
CUTLASS_HOST_DEVICE
|
| 499 |
+
PredicatedTileIterator2dThreadTile(
|
| 500 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 501 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 502 |
+
TensorCoord extent, ///< Extent of tensor
|
| 503 |
+
int thread_id ///< ID of each participating thread
|
| 504 |
+
): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 505 |
+
|
| 506 |
+
/// Adds a pointer offset in units of Element
|
| 507 |
+
CUTLASS_HOST_DEVICE
|
| 508 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 509 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
/// Advances to the next tile in memory.
|
| 513 |
+
///
|
| 514 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 515 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 516 |
+
/// are lightweight and must only update the internal pointer.
|
| 517 |
+
CUTLASS_HOST_DEVICE
|
| 518 |
+
PredicatedTileIterator2dThreadTile &operator++() {
|
| 519 |
+
++iterator_;
|
| 520 |
+
return *this;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
/// Advances to the next tile in memory.
|
| 524 |
+
///
|
| 525 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 526 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 527 |
+
/// are lightweight and must only update the internal pointer.
|
| 528 |
+
CUTLASS_HOST_DEVICE
|
| 529 |
+
PredicatedTileIterator2dThreadTile operator++(int) {
|
| 530 |
+
PredicatedTileIterator2dThreadTile self(*this);
|
| 531 |
+
operator++();
|
| 532 |
+
return self;
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
/// Clears the predicate set efficiently
|
| 536 |
+
CUTLASS_HOST_DEVICE
|
| 537 |
+
void clear_mask(bool enable = true) {
|
| 538 |
+
iterator_.clear_mask(enable);
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
/// Clears the predicate set efficiently
|
| 542 |
+
CUTLASS_HOST_DEVICE
|
| 543 |
+
void enable_mask() {
|
| 544 |
+
iterator_.enable_mask();
|
| 545 |
+
}
|
| 546 |
+
|
| 547 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 548 |
+
CUTLASS_HOST_DEVICE
|
| 549 |
+
void set_mask(Mask const &mask) {
|
| 550 |
+
iterator_.set_mask(mask);
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
/// Gets the mask
|
| 554 |
+
CUTLASS_HOST_DEVICE
|
| 555 |
+
void get_mask(Mask &mask) {
|
| 556 |
+
iterator_.get_mask(mask);
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
/// Loads a fragment from memory
|
| 560 |
+
CUTLASS_DEVICE
|
| 561 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 562 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
/// Loads a fragment from memory
|
| 566 |
+
CUTLASS_DEVICE
|
| 567 |
+
void load(Fragment &frag) {
|
| 568 |
+
load_with_pointer_offset(frag, 0);
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
/// Store a fragment to memory
|
| 572 |
+
CUTLASS_DEVICE
|
| 573 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 574 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
/// Store a fragment to memory
|
| 578 |
+
CUTLASS_DEVICE
|
| 579 |
+
void store(Fragment const &frag) {
|
| 580 |
+
store_with_pointer_offset(frag, 0);
|
| 581 |
+
}
|
| 582 |
+
};
|
| 583 |
+
|
| 584 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 585 |
+
|
| 586 |
+
/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
|
| 587 |
+
///
|
| 588 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 589 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 590 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 591 |
+
/// MaskedTileIteratorConcept
|
| 592 |
+
///
|
| 593 |
+
template <
|
| 594 |
+
typename Shape_,
|
| 595 |
+
typename Element_,
|
| 596 |
+
int AdvanceRank,
|
| 597 |
+
typename ThreadMap_,
|
| 598 |
+
bool Transpose_
|
| 599 |
+
>
|
| 600 |
+
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, Transpose_> {
|
| 601 |
+
public:
|
| 602 |
+
|
| 603 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 604 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 605 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 606 |
+
|
| 607 |
+
using Shape = Shape_;
|
| 608 |
+
using Element = Element_;
|
| 609 |
+
using Layout = layout::RowMajor;
|
| 610 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 611 |
+
using ThreadMap = ThreadMap_;
|
| 612 |
+
static bool const Transpose = Transpose_;
|
| 613 |
+
|
| 614 |
+
using Index = typename Layout::Index;
|
| 615 |
+
using LongIndex = typename Layout::LongIndex;
|
| 616 |
+
|
| 617 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 618 |
+
using TensorView = TensorView<Element, Layout>;
|
| 619 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 620 |
+
|
| 621 |
+
using Pointer = Element *;
|
| 622 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 623 |
+
|
| 624 |
+
using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
|
| 625 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 626 |
+
Element,
|
| 627 |
+
layout::PitchLinear,
|
| 628 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 629 |
+
ThreadMap,
|
| 630 |
+
Transpose
|
| 631 |
+
>;
|
| 632 |
+
|
| 633 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 634 |
+
|
| 635 |
+
/// Fragment object to be loaded or stored
|
| 636 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;
|
| 637 |
+
|
| 638 |
+
/// Predicate vector stores mask to guard accesses
|
| 639 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 640 |
+
|
| 641 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 642 |
+
class Params {
|
| 643 |
+
private:
|
| 644 |
+
|
| 645 |
+
friend PredicatedTileIterator2dThreadTile;
|
| 646 |
+
|
| 647 |
+
/// Parameters object
|
| 648 |
+
typename UnderlyingIterator::Params params_;
|
| 649 |
+
|
| 650 |
+
public:
|
| 651 |
+
|
| 652 |
+
CUTLASS_HOST_DEVICE
|
| 653 |
+
Params() { }
|
| 654 |
+
|
| 655 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 656 |
+
CUTLASS_HOST_DEVICE
|
| 657 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { }
|
| 658 |
+
|
| 659 |
+
CUTLASS_HOST_DEVICE
|
| 660 |
+
Params(typename UnderlyingIterator::Params::Base const &base)
|
| 661 |
+
: params_(base) {}
|
| 662 |
+
};
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
private:
|
| 666 |
+
|
| 667 |
+
//
|
| 668 |
+
// Data members
|
| 669 |
+
//
|
| 670 |
+
|
| 671 |
+
/// Underlying pitch-linear tile iterator
|
| 672 |
+
UnderlyingIterator iterator_;
|
| 673 |
+
|
| 674 |
+
public:
|
| 675 |
+
|
| 676 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 677 |
+
CUTLASS_HOST_DEVICE
|
| 678 |
+
PredicatedTileIterator2dThreadTile(
|
| 679 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 680 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 681 |
+
TensorCoord extent, ///< Extent of tensor
|
| 682 |
+
int thread_id, ///< ID of each participating thread
|
| 683 |
+
TensorCoord const &threadblock_offset, ///< Initial offset of threadblock
|
| 684 |
+
int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization
|
| 685 |
+
):
|
| 686 |
+
iterator_(
|
| 687 |
+
params.params_,
|
| 688 |
+
pointer,
|
| 689 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 690 |
+
thread_id,
|
| 691 |
+
layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
|
| 692 |
+
) { }
|
| 693 |
+
|
| 694 |
+
/// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
|
| 695 |
+
CUTLASS_HOST_DEVICE
|
| 696 |
+
PredicatedTileIterator2dThreadTile(
|
| 697 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 698 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 699 |
+
TensorCoord extent, ///< Extent of tensor
|
| 700 |
+
int thread_id ///< ID of each participating thread
|
| 701 |
+
): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 702 |
+
|
| 703 |
+
/// Adds a pointer offset in units of Element
|
| 704 |
+
CUTLASS_HOST_DEVICE
|
| 705 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 706 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
/// Advances to the next tile in memory.
|
| 710 |
+
///
|
| 711 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 712 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 713 |
+
/// are lightweight and must only update the internal pointer.
|
| 714 |
+
CUTLASS_HOST_DEVICE
|
| 715 |
+
PredicatedTileIterator2dThreadTile &operator++() {
|
| 716 |
+
++iterator_;
|
| 717 |
+
return *this;
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
/// Advances to the next tile in memory.
|
| 721 |
+
///
|
| 722 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 723 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 724 |
+
/// are lightweight and must only update the internal pointer.
|
| 725 |
+
CUTLASS_HOST_DEVICE
|
| 726 |
+
PredicatedTileIterator2dThreadTile operator++(int) {
|
| 727 |
+
PredicatedTileIterator2dThreadTile self(*this);
|
| 728 |
+
operator++();
|
| 729 |
+
return self;
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
/// Clears the predicate set efficiently
|
| 733 |
+
CUTLASS_HOST_DEVICE
|
| 734 |
+
void clear_mask(bool enable = true) {
|
| 735 |
+
iterator_.clear_mask(enable);
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
/// Clears the predicate set efficiently
|
| 739 |
+
CUTLASS_HOST_DEVICE
|
| 740 |
+
void enable_mask() {
|
| 741 |
+
iterator_.enable_mask();
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 745 |
+
CUTLASS_HOST_DEVICE
|
| 746 |
+
void set_mask(Mask const &mask) {
|
| 747 |
+
iterator_.set_mask(mask);
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
/// Gets the mask
|
| 751 |
+
CUTLASS_HOST_DEVICE
|
| 752 |
+
void get_mask(Mask &mask) {
|
| 753 |
+
iterator_.get_mask(mask);
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
/// Loads a fragment from memory
|
| 757 |
+
CUTLASS_DEVICE
|
| 758 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 759 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 760 |
+
}
|
| 761 |
+
|
| 762 |
+
/// Loads a fragment from memory
|
| 763 |
+
CUTLASS_DEVICE
|
| 764 |
+
void load(Fragment &frag) {
|
| 765 |
+
load_with_pointer_offset(frag, 0);
|
| 766 |
+
}
|
| 767 |
+
|
| 768 |
+
/// Store a fragment to memory
|
| 769 |
+
CUTLASS_DEVICE
|
| 770 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 771 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
/// Store a fragment to memory
|
| 775 |
+
CUTLASS_DEVICE
|
| 776 |
+
void store(Fragment const &frag) {
|
| 777 |
+
store_with_pointer_offset(frag, 0);
|
| 778 |
+
}
|
| 779 |
+
};
|
| 780 |
+
|
| 781 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 782 |
+
|
| 783 |
+
} // namespace threadblock
|
| 784 |
+
} // namespace transform
|
| 785 |
+
} // namespace cutlass
|
| 786 |
+
|
| 787 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing loading of tiles from pitch-linear rank=2 tensors.
|
| 33 |
+
|
| 34 |
+
This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile
|
| 35 |
+
first, with the objective of minimizing predicate mask updates during steady-state operation.
|
| 36 |
+
|
| 37 |
+
A precomputed "Params" object minimizes the amount of state that must be stored in registers,
|
| 38 |
+
and integer addition is used to advance the pointer through memory.
|
| 39 |
+
*/
|
| 40 |
+
|
| 41 |
+
#pragma once
|
| 42 |
+
|
| 43 |
+
#include "cutlass/arch/memory.h"
|
| 44 |
+
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h"
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
namespace cutlass {
|
| 49 |
+
namespace transform {
|
| 50 |
+
namespace threadblock {
|
| 51 |
+
|
| 52 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
/// PredicatedTileIteratorTriangularMatrix
|
| 55 |
+
///
|
| 56 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 57 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 58 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 59 |
+
/// MaskedTileIteratorConcept
|
| 60 |
+
///
|
| 61 |
+
/// Regular tile iterator using a precomputed control structure to minimize register liveness
|
| 62 |
+
/// and integer arithmetic.
|
| 63 |
+
///
|
| 64 |
+
/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed.
|
| 65 |
+
///
|
| 66 |
+
/// Base pointer and tensor extents may be specified at the time the iterator is constructed.
|
| 67 |
+
/// Subsequently, they are assumed to be immutable.
|
| 68 |
+
///
|
| 69 |
+
/// Adding a logical coordinate offset may be performed at the time the iterator is constructed.
|
| 70 |
+
/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive.
|
| 71 |
+
///
|
| 72 |
+
/// Vistitation order is intended to first visit a "residual" tile that may be partially full in
|
| 73 |
+
/// both the advance dimension and the steady-state dimension. This is assumed to be the last
|
| 74 |
+
/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to
|
| 75 |
+
/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent
|
| 76 |
+
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
| 77 |
+
/// live register state and pointer arithmetic instructions.
|
| 78 |
+
///
|
| 79 |
+
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
| 80 |
+
/// outside any looping structure to minimize integer arithmetic.
|
| 81 |
+
///
|
| 82 |
+
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
| 83 |
+
/// the iterator.
|
| 84 |
+
///
|
| 85 |
+
///
|
| 86 |
+
/// Example:
|
| 87 |
+
///
|
| 88 |
+
/// An efficient pipeline structure may be constructed as follows:
|
| 89 |
+
///
|
| 90 |
+
// template <typename Iterator>
|
| 91 |
+
// __global__ void kernel(
|
| 92 |
+
// typename Iterator::Params params,
|
| 93 |
+
// typename Iterator::Element *ptr,
|
| 94 |
+
// TensorCoord extent) {
|
| 95 |
+
//
|
| 96 |
+
// typename Iterator::Fragment fragment;
|
| 97 |
+
//
|
| 98 |
+
// TensorCoord threadblock_offset(0, 0);
|
| 99 |
+
//
|
| 100 |
+
// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets);
|
| 101 |
+
//
|
| 102 |
+
//
|
| 103 |
+
// fragment = *iter; // load "residue" tile first
|
| 104 |
+
// ++iter; // advance to first "steady state" tile and update internal masks
|
| 105 |
+
//
|
| 106 |
+
//
|
| 107 |
+
// #pragma unroll
|
| 108 |
+
// for (int i = Remaining - 1; i >= 0; --i) {
|
| 109 |
+
//
|
| 110 |
+
// f(fragment);
|
| 111 |
+
//
|
| 112 |
+
// if (!i) {
|
| 113 |
+
// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
|
| 114 |
+
// }
|
| 115 |
+
//
|
| 116 |
+
// fragment = *iter; // load tile during "steady state" phase
|
| 117 |
+
// ++iter; // advance to next tile - lightweight due to steady-state masks
|
| 118 |
+
// }
|
| 119 |
+
// }
|
| 120 |
+
//
|
| 121 |
+
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
|
| 122 |
+
//
|
| 123 |
+
// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix;
|
| 124 |
+
//
|
| 125 |
+
// typename Iterator::Params params(view.layout());
|
| 126 |
+
//
|
| 127 |
+
// kernel<Iterator>(params, view.data());
|
| 128 |
+
// }
|
| 129 |
+
///
|
| 130 |
+
///
|
| 131 |
+
template <
|
| 132 |
+
typename Shape,
|
| 133 |
+
typename Element,
|
| 134 |
+
typename Layout,
|
| 135 |
+
int AdvanceRank,
|
| 136 |
+
typename ThreadMap,
|
| 137 |
+
SideMode kSideMode,
|
| 138 |
+
FillMode kFillMode,
|
| 139 |
+
DiagType kDiagType,
|
| 140 |
+
int AccessSize = ThreadMap::kElementsPerAccess
|
| 141 |
+
>
|
| 142 |
+
class PredicatedTileIteratorTriangularMatrix;
|
| 143 |
+
|
| 144 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 145 |
+
|
| 146 |
+
/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data.
|
| 147 |
+
///
|
| 148 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 149 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 150 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 151 |
+
/// MaskedTileIteratorConcept
|
| 152 |
+
///
|
| 153 |
+
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_,
|
| 154 |
+
SideMode kSideMode, FillMode kFillMode, DiagType kDiagType,
|
| 155 |
+
int AccessSize>
|
| 156 |
+
class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_,
|
| 157 |
+
kSideMode, kFillMode, kDiagType,
|
| 158 |
+
AccessSize> {
|
| 159 |
+
public:
|
| 160 |
+
static_assert(
|
| 161 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 162 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 163 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 164 |
+
|
| 165 |
+
using Shape = Shape_;
|
| 166 |
+
using Element = Element_;
|
| 167 |
+
using Layout = layout::PitchLinear;
|
| 168 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 169 |
+
using ThreadMap = ThreadMap_;
|
| 170 |
+
|
| 171 |
+
using Index = typename Layout::Index;
|
| 172 |
+
using LongIndex = typename Layout::LongIndex;
|
| 173 |
+
|
| 174 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 175 |
+
using TensorView = TensorView<Element, Layout>;
|
| 176 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 177 |
+
|
| 178 |
+
using Pointer = Element *;
|
| 179 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 180 |
+
|
| 181 |
+
/// Type used for internal memory accesses
|
| 182 |
+
using AccessType = AlignedArray<Element, AccessSize, (AccessSize * sizeof_bits<Element>::value / 8)>;
|
| 183 |
+
|
| 184 |
+
/// Underlying iterator to compute the addresses
|
| 185 |
+
using TileAccessIterator =
|
| 186 |
+
PredicatedTileAccessIteratorTriangularMatrix<Shape, Element, Layout, kAdvanceRank,
|
| 187 |
+
ThreadMap, kSideMode, kFillMode, kDiagType, AccessType>;
|
| 188 |
+
|
| 189 |
+
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
|
| 190 |
+
|
| 191 |
+
/// Fragment object to be loaded or stored
|
| 192 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
|
| 193 |
+
ThreadMap::kElementsPerAccess>;
|
| 194 |
+
|
| 195 |
+
/// Predicate vector stores mask to guard accesses
|
| 196 |
+
using Mask = typename TileAccessIterator::Mask;
|
| 197 |
+
|
| 198 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 199 |
+
class Params {
|
| 200 |
+
public:
|
| 201 |
+
friend PredicatedTileIteratorTriangularMatrix;
|
| 202 |
+
|
| 203 |
+
private:
|
| 204 |
+
/// Parameters object
|
| 205 |
+
typename TileAccessIterator::Params params_;
|
| 206 |
+
|
| 207 |
+
public:
|
| 208 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 209 |
+
CUTLASS_HOST_DEVICE
|
| 210 |
+
Params(Layout const &layout) : params_(layout) { }
|
| 211 |
+
|
| 212 |
+
CUTLASS_HOST_DEVICE
|
| 213 |
+
Params() { }
|
| 214 |
+
};
|
| 215 |
+
|
| 216 |
+
private:
|
| 217 |
+
/// Internal pointer type permits fast address arithmetic
|
| 218 |
+
using BytePointer = char *;
|
| 219 |
+
|
| 220 |
+
private:
|
| 221 |
+
//
|
| 222 |
+
// Data members
|
| 223 |
+
//
|
| 224 |
+
|
| 225 |
+
/// Data member to the tile access iterator
|
| 226 |
+
TileAccessIterator address_iterator_;
|
| 227 |
+
|
| 228 |
+
public:
|
| 229 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 230 |
+
/// and thread ID
|
| 231 |
+
CUTLASS_HOST_DEVICE
|
| 232 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 233 |
+
/// Precomputed parameters object
|
| 234 |
+
Params const ¶ms,
|
| 235 |
+
/// Pointer to start of tensor
|
| 236 |
+
Pointer pointer,
|
| 237 |
+
/// Extent of tensor
|
| 238 |
+
TensorCoord extent,
|
| 239 |
+
/// ID of each participating thread
|
| 240 |
+
int thread_id,
|
| 241 |
+
/// Initial offset of threadblock
|
| 242 |
+
TensorCoord const &threadblock_offset)
|
| 243 |
+
: address_iterator_(params.params_, pointer, extent, thread_id,
|
| 244 |
+
threadblock_offset) {}
|
| 245 |
+
|
| 246 |
+
/// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
|
| 247 |
+
CUTLASS_HOST_DEVICE
|
| 248 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 249 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 250 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 251 |
+
TensorCoord extent, ///< Extent of tensor
|
| 252 |
+
int thread_id ///< ID of each participating thread
|
| 253 |
+
)
|
| 254 |
+
: PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id,
|
| 255 |
+
make_Coord(0, 0)) {}
|
| 256 |
+
|
| 257 |
+
/// Adds a pointer offset in units of Element
|
| 258 |
+
CUTLASS_HOST_DEVICE
|
| 259 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 260 |
+
address_iterator_.add_pointer_offset(pointer_offset);
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
/// Advances to the next tile in memory.
|
| 264 |
+
///
|
| 265 |
+
/// The first time this method is called, predicates are updated, and the
|
| 266 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 267 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 268 |
+
/// pointer.
|
| 269 |
+
CUTLASS_HOST_DEVICE
|
| 270 |
+
PredicatedTileIteratorTriangularMatrix &operator++() {
|
| 271 |
+
if (kAdvanceRank)
|
| 272 |
+
address_iterator_.add_tile_offset({0, 1});
|
| 273 |
+
else
|
| 274 |
+
address_iterator_.add_tile_offset({1, 0});
|
| 275 |
+
|
| 276 |
+
return *this;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
/// Advances to the next tile in memory.
|
| 280 |
+
///
|
| 281 |
+
/// The first time this method is called, predicates are updated, and the
|
| 282 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 283 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 284 |
+
/// pointer.
|
| 285 |
+
CUTLASS_HOST_DEVICE
|
| 286 |
+
PredicatedTileIteratorTriangularMatrix operator++(int) {
|
| 287 |
+
PredicatedTileIteratorTriangularMatrix self(*this);
|
| 288 |
+
operator++();
|
| 289 |
+
return self;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
/// Clears the predicate set efficiently
|
| 293 |
+
CUTLASS_HOST_DEVICE
|
| 294 |
+
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
|
| 295 |
+
|
| 296 |
+
/// Clears the predicate set efficiently
|
| 297 |
+
CUTLASS_HOST_DEVICE
|
| 298 |
+
void enable_mask() { address_iterator_.enable_mask(); }
|
| 299 |
+
|
| 300 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 301 |
+
CUTLASS_HOST_DEVICE
|
| 302 |
+
void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
|
| 303 |
+
|
| 304 |
+
/// Gets the mask
|
| 305 |
+
CUTLASS_HOST_DEVICE
|
| 306 |
+
void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
|
| 307 |
+
|
| 308 |
+
CUTLASS_DEVICE
|
| 309 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 310 |
+
load_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
CUTLASS_DEVICE
|
| 314 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 315 |
+
|
| 316 |
+
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
| 317 |
+
|
| 318 |
+
CUTLASS_PRAGMA_UNROLL
|
| 319 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 320 |
+
CUTLASS_PRAGMA_UNROLL
|
| 321 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 322 |
+
|
| 323 |
+
CUTLASS_PRAGMA_UNROLL
|
| 324 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 325 |
+
|
| 326 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 327 |
+
|
| 328 |
+
address_iterator_.set_iteration_index(idx);
|
| 329 |
+
char const *byte_ptr = reinterpret_cast<char const *>(address_iterator_.get()) + byte_offset;
|
| 330 |
+
|
| 331 |
+
AccessType const *access_ptr = reinterpret_cast<AccessType const *>(byte_ptr);
|
| 332 |
+
|
| 333 |
+
cutlass::arch::global_load<AccessType,
|
| 334 |
+
sizeof(AccessType)
|
| 335 |
+
>(
|
| 336 |
+
frag_ptr[idx], access_ptr, address_iterator_.valid());
|
| 337 |
+
|
| 338 |
+
++address_iterator_;
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
}
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
/// Loads a fragment from memory
|
| 345 |
+
CUTLASS_DEVICE
|
| 346 |
+
void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
|
| 347 |
+
|
| 348 |
+
/// Store a fragment to memory
|
| 349 |
+
CUTLASS_DEVICE
|
| 350 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 351 |
+
store_with_byte_offset(frag, pointer_offset * sizeof_bits<Element>::value / 8);
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
/// Store a fragment to memory
|
| 355 |
+
CUTLASS_DEVICE
|
| 356 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 357 |
+
address_iterator_.set_iteration_index(0);
|
| 358 |
+
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
| 359 |
+
|
| 360 |
+
CUTLASS_PRAGMA_UNROLL
|
| 361 |
+
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
|
| 362 |
+
CUTLASS_PRAGMA_UNROLL
|
| 363 |
+
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
|
| 364 |
+
CUTLASS_PRAGMA_UNROLL
|
| 365 |
+
for (int v = 0; v < kAccessesPerVector; ++v) {
|
| 366 |
+
|
| 367 |
+
int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
|
| 368 |
+
|
| 369 |
+
char *byte_ptr = reinterpret_cast<char *>(address_iterator_.get()) + byte_offset;
|
| 370 |
+
AccessType *access_ptr = reinterpret_cast<AccessType *>(byte_ptr);
|
| 371 |
+
|
| 372 |
+
if (address_iterator_.valid()) {
|
| 373 |
+
*access_ptr = frag_ptr[idx];
|
| 374 |
+
}
|
| 375 |
+
++address_iterator_;
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
/// Store a fragment to memory
|
| 382 |
+
CUTLASS_DEVICE
|
| 383 |
+
void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
|
| 384 |
+
};
|
| 385 |
+
|
| 386 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 387 |
+
|
| 388 |
+
/// Specialization of PredicatedTileIteratorTriangularMatrix for column-major data.
|
| 389 |
+
///
|
| 390 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 391 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 392 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 393 |
+
/// MaskedTileIteratorConcept
|
| 394 |
+
///
|
| 395 |
+
template <
|
| 396 |
+
typename Shape_,
|
| 397 |
+
typename Element_,
|
| 398 |
+
int AdvanceRank,
|
| 399 |
+
typename ThreadMap_,
|
| 400 |
+
SideMode kSideMode,
|
| 401 |
+
FillMode kFillMode,
|
| 402 |
+
DiagType kDiagType,
|
| 403 |
+
int AccessSize
|
| 404 |
+
>
|
| 405 |
+
class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_,
|
| 406 |
+
kSideMode, kFillMode, kDiagType,
|
| 407 |
+
AccessSize> {
|
| 408 |
+
public:
|
| 409 |
+
|
| 410 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 411 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 412 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 413 |
+
|
| 414 |
+
using Shape = Shape_;
|
| 415 |
+
using Element = Element_;
|
| 416 |
+
using Layout = layout::ColumnMajor;
|
| 417 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 418 |
+
using ThreadMap = ThreadMap_;
|
| 419 |
+
|
| 420 |
+
using Index = typename Layout::Index;
|
| 421 |
+
using LongIndex = typename Layout::LongIndex;
|
| 422 |
+
|
| 423 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 424 |
+
using TensorView = TensorView<Element, Layout>;
|
| 425 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 426 |
+
|
| 427 |
+
using Pointer = Element *;
|
| 428 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 429 |
+
|
| 430 |
+
using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix<
|
| 431 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
|
| 432 |
+
Element,
|
| 433 |
+
layout::PitchLinear,
|
| 434 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 435 |
+
ThreadMap,
|
| 436 |
+
kSideMode,
|
| 437 |
+
kFillMode,
|
| 438 |
+
kDiagType,
|
| 439 |
+
AccessSize
|
| 440 |
+
>;
|
| 441 |
+
|
| 442 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 443 |
+
|
| 444 |
+
/// Fragment object to be loaded or stored
|
| 445 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 446 |
+
|
| 447 |
+
/// Predicate vector stores mask to guard accesses
|
| 448 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 449 |
+
|
| 450 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 451 |
+
class Params {
|
| 452 |
+
private:
|
| 453 |
+
|
| 454 |
+
friend PredicatedTileIteratorTriangularMatrix;
|
| 455 |
+
|
| 456 |
+
/// Parameters object
|
| 457 |
+
typename UnderlyingIterator::Params params_;
|
| 458 |
+
|
| 459 |
+
public:
|
| 460 |
+
|
| 461 |
+
CUTLASS_HOST_DEVICE
|
| 462 |
+
Params() { }
|
| 463 |
+
|
| 464 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 465 |
+
CUTLASS_HOST_DEVICE
|
| 466 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
|
| 467 |
+
|
| 468 |
+
}
|
| 469 |
+
};
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
private:
|
| 473 |
+
|
| 474 |
+
//
|
| 475 |
+
// Data members
|
| 476 |
+
//
|
| 477 |
+
|
| 478 |
+
/// Underlying pitch-linear tile iterator
|
| 479 |
+
UnderlyingIterator iterator_;
|
| 480 |
+
|
| 481 |
+
public:
|
| 482 |
+
|
| 483 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 484 |
+
CUTLASS_HOST_DEVICE
|
| 485 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 486 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 487 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 488 |
+
TensorCoord extent, ///< Extent of tensor
|
| 489 |
+
int thread_id, ///< ID of each participating thread
|
| 490 |
+
TensorCoord const &threadblock_offset ///< Initial offset of threadblock
|
| 491 |
+
):
|
| 492 |
+
iterator_(
|
| 493 |
+
params.params_,
|
| 494 |
+
pointer,
|
| 495 |
+
layout::PitchLinearCoord(extent.row(), extent.column()),
|
| 496 |
+
thread_id,
|
| 497 |
+
layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
|
| 498 |
+
) { }
|
| 499 |
+
|
| 500 |
+
/// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
|
| 501 |
+
CUTLASS_HOST_DEVICE
|
| 502 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 503 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 504 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 505 |
+
TensorCoord extent, ///< Extent of tensor
|
| 506 |
+
int thread_id ///< ID of each participating thread
|
| 507 |
+
): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 508 |
+
|
| 509 |
+
/// Adds a pointer offset in units of Element
|
| 510 |
+
CUTLASS_HOST_DEVICE
|
| 511 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 512 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
/// Advances to the next tile in memory.
|
| 516 |
+
///
|
| 517 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 518 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 519 |
+
/// are lightweight and must only update the internal pointer.
|
| 520 |
+
CUTLASS_HOST_DEVICE
|
| 521 |
+
PredicatedTileIteratorTriangularMatrix &operator++() {
|
| 522 |
+
++iterator_;
|
| 523 |
+
return *this;
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
/// Advances to the next tile in memory.
|
| 527 |
+
///
|
| 528 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 529 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 530 |
+
/// are lightweight and must only update the internal pointer.
|
| 531 |
+
CUTLASS_HOST_DEVICE
|
| 532 |
+
PredicatedTileIteratorTriangularMatrix operator++(int) {
|
| 533 |
+
PredicatedTileIteratorTriangularMatrix self(*this);
|
| 534 |
+
operator++();
|
| 535 |
+
return self;
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
/// Clears the predicate set efficiently
|
| 539 |
+
CUTLASS_HOST_DEVICE
|
| 540 |
+
void clear_mask(bool enable = true) {
|
| 541 |
+
iterator_.clear_mask(enable);
|
| 542 |
+
}
|
| 543 |
+
|
| 544 |
+
/// Clears the predicate set efficiently
|
| 545 |
+
CUTLASS_HOST_DEVICE
|
| 546 |
+
void enable_mask() {
|
| 547 |
+
iterator_.enable_mask();
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 551 |
+
CUTLASS_HOST_DEVICE
|
| 552 |
+
void set_mask(Mask const &mask) {
|
| 553 |
+
iterator_.set_mask(mask);
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/// Gets the mask
|
| 557 |
+
CUTLASS_HOST_DEVICE
|
| 558 |
+
void get_mask(Mask &mask) {
|
| 559 |
+
iterator_.get_mask(mask);
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
/// Loads a fragment from memory
|
| 563 |
+
CUTLASS_DEVICE
|
| 564 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 565 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
/// Loads a fragment from memory
|
| 569 |
+
CUTLASS_DEVICE
|
| 570 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 571 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 572 |
+
}
|
| 573 |
+
|
| 574 |
+
/// Loads a fragment from memory
|
| 575 |
+
CUTLASS_DEVICE
|
| 576 |
+
void load(Fragment &frag) {
|
| 577 |
+
load_with_pointer_offset(frag, 0);
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
/// Store a fragment to memory
|
| 581 |
+
CUTLASS_DEVICE
|
| 582 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 583 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
/// Store a fragment to memory
|
| 587 |
+
CUTLASS_DEVICE
|
| 588 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 589 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
/// Store a fragment to memory
|
| 593 |
+
CUTLASS_DEVICE
|
| 594 |
+
void store(Fragment const &frag) {
|
| 595 |
+
store_with_pointer_offset(frag, 0);
|
| 596 |
+
}
|
| 597 |
+
};
|
| 598 |
+
|
| 599 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 600 |
+
|
| 601 |
+
/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data.
|
| 602 |
+
///
|
| 603 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 604 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 605 |
+
/// WriteableContiguousTileIteratorConcept |
|
| 606 |
+
/// MaskedTileIteratorConcept
|
| 607 |
+
///
|
| 608 |
+
template <
|
| 609 |
+
typename Shape_,
|
| 610 |
+
typename Element_,
|
| 611 |
+
int AdvanceRank,
|
| 612 |
+
typename ThreadMap_,
|
| 613 |
+
SideMode kSideMode,
|
| 614 |
+
FillMode kFillMode,
|
| 615 |
+
DiagType kDiagType,
|
| 616 |
+
int AccessSize
|
| 617 |
+
>
|
| 618 |
+
class PredicatedTileIteratorTriangularMatrix<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_,
|
| 619 |
+
kSideMode, kFillMode, kDiagType,
|
| 620 |
+
AccessSize> {
|
| 621 |
+
public:
|
| 622 |
+
|
| 623 |
+
static_assert(AdvanceRank == 0 || AdvanceRank == 1,
|
| 624 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 625 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 626 |
+
|
| 627 |
+
using Shape = Shape_;
|
| 628 |
+
using Element = Element_;
|
| 629 |
+
using Layout = layout::RowMajor;
|
| 630 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 631 |
+
using ThreadMap = ThreadMap_;
|
| 632 |
+
|
| 633 |
+
using Index = typename Layout::Index;
|
| 634 |
+
using LongIndex = typename Layout::LongIndex;
|
| 635 |
+
|
| 636 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 637 |
+
using TensorView = TensorView<Element, Layout>;
|
| 638 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 639 |
+
|
| 640 |
+
using Pointer = Element *;
|
| 641 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 642 |
+
|
| 643 |
+
using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix<
|
| 644 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 645 |
+
Element,
|
| 646 |
+
layout::PitchLinear,
|
| 647 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 648 |
+
ThreadMap,
|
| 649 |
+
kSideMode,
|
| 650 |
+
kFillMode,
|
| 651 |
+
kDiagType,
|
| 652 |
+
AccessSize
|
| 653 |
+
>;
|
| 654 |
+
|
| 655 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 656 |
+
|
| 657 |
+
/// Fragment object to be loaded or stored
|
| 658 |
+
using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
|
| 659 |
+
|
| 660 |
+
/// Predicate vector stores mask to guard accesses
|
| 661 |
+
using Mask = typename UnderlyingIterator::Mask;
|
| 662 |
+
|
| 663 |
+
/// Parameters object is precomputed state and is host-constructible
|
| 664 |
+
class Params {
|
| 665 |
+
private:
|
| 666 |
+
|
| 667 |
+
friend PredicatedTileIteratorTriangularMatrix;
|
| 668 |
+
|
| 669 |
+
/// Parameters object
|
| 670 |
+
typename UnderlyingIterator::Params params_;
|
| 671 |
+
|
| 672 |
+
public:
|
| 673 |
+
|
| 674 |
+
CUTLASS_HOST_DEVICE
|
| 675 |
+
Params() { }
|
| 676 |
+
|
| 677 |
+
/// Construct the Params object given a pitch-linear tensor's layout
|
| 678 |
+
CUTLASS_HOST_DEVICE
|
| 679 |
+
Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
|
| 680 |
+
|
| 681 |
+
};
|
| 682 |
+
};
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
private:
|
| 686 |
+
|
| 687 |
+
//
|
| 688 |
+
// Data members
|
| 689 |
+
//
|
| 690 |
+
|
| 691 |
+
/// Underlying pitch-linear tile iterator
|
| 692 |
+
UnderlyingIterator iterator_;
|
| 693 |
+
|
| 694 |
+
public:
|
| 695 |
+
|
| 696 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
|
| 697 |
+
CUTLASS_HOST_DEVICE
|
| 698 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 699 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 700 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 701 |
+
TensorCoord extent, ///< Extent of tensor
|
| 702 |
+
int thread_id, ///< ID of each participating thread
|
| 703 |
+
TensorCoord const &threadblock_offset ///< Initial offset of threadblock
|
| 704 |
+
):
|
| 705 |
+
iterator_(
|
| 706 |
+
params.params_,
|
| 707 |
+
pointer,
|
| 708 |
+
layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 709 |
+
thread_id,
|
| 710 |
+
layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
|
| 711 |
+
) { }
|
| 712 |
+
|
| 713 |
+
/// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset
|
| 714 |
+
CUTLASS_HOST_DEVICE
|
| 715 |
+
PredicatedTileIteratorTriangularMatrix(
|
| 716 |
+
Params const ¶ms, ///< Precomputed parameters object
|
| 717 |
+
Pointer pointer, ///< Pointer to start of tensor
|
| 718 |
+
TensorCoord extent, ///< Extent of tensor
|
| 719 |
+
int thread_id ///< ID of each participating thread
|
| 720 |
+
): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
|
| 721 |
+
|
| 722 |
+
/// Adds a pointer offset in units of Element
|
| 723 |
+
CUTLASS_HOST_DEVICE
|
| 724 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 725 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
/// Advances to the next tile in memory.
|
| 729 |
+
///
|
| 730 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 731 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 732 |
+
/// are lightweight and must only update the internal pointer.
|
| 733 |
+
CUTLASS_HOST_DEVICE
|
| 734 |
+
PredicatedTileIteratorTriangularMatrix &operator++() {
|
| 735 |
+
++iterator_;
|
| 736 |
+
return *this;
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
/// Advances to the next tile in memory.
|
| 740 |
+
///
|
| 741 |
+
/// The first time this method is called, predicates are updated, and the iterator's
|
| 742 |
+
/// internal pointer is reverted to the first "steady state" tile. Subsequent calls
|
| 743 |
+
/// are lightweight and must only update the internal pointer.
|
| 744 |
+
CUTLASS_HOST_DEVICE
|
| 745 |
+
PredicatedTileIteratorTriangularMatrix operator++(int) {
|
| 746 |
+
PredicatedTileIteratorTriangularMatrix self(*this);
|
| 747 |
+
operator++();
|
| 748 |
+
return self;
|
| 749 |
+
}
|
| 750 |
+
|
| 751 |
+
/// Clears the predicate set efficiently
|
| 752 |
+
CUTLASS_HOST_DEVICE
|
| 753 |
+
void clear_mask(bool enable = true) {
|
| 754 |
+
iterator_.clear_mask(enable);
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
/// Clears the predicate set efficiently
|
| 758 |
+
CUTLASS_HOST_DEVICE
|
| 759 |
+
void enable_mask() {
|
| 760 |
+
iterator_.enable_mask();
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
/// Sets the predicate mask, overriding value stored in predicate iterator
|
| 764 |
+
CUTLASS_HOST_DEVICE
|
| 765 |
+
void set_mask(Mask const &mask) {
|
| 766 |
+
iterator_.set_mask(mask);
|
| 767 |
+
}
|
| 768 |
+
|
| 769 |
+
/// Gets the mask
|
| 770 |
+
CUTLASS_HOST_DEVICE
|
| 771 |
+
void get_mask(Mask &mask) {
|
| 772 |
+
iterator_.get_mask(mask);
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
/// Loads a fragment from memory
|
| 776 |
+
CUTLASS_DEVICE
|
| 777 |
+
void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
|
| 778 |
+
iterator_.load_with_pointer_offset(frag, pointer_offset);
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
/// Loads a fragment from memory
|
| 782 |
+
CUTLASS_DEVICE
|
| 783 |
+
void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
|
| 784 |
+
iterator_.load_with_byte_offset(frag, byte_offset);
|
| 785 |
+
}
|
| 786 |
+
|
| 787 |
+
/// Loads a fragment from memory
|
| 788 |
+
CUTLASS_DEVICE
|
| 789 |
+
void load(Fragment &frag) {
|
| 790 |
+
load_with_pointer_offset(frag, 0);
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
/// Store a fragment to memory
|
| 794 |
+
CUTLASS_DEVICE
|
| 795 |
+
void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
|
| 796 |
+
iterator_.store_with_pointer_offset(frag, pointer_offset);
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
/// Store a fragment to memory
|
| 800 |
+
CUTLASS_DEVICE
|
| 801 |
+
void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
|
| 802 |
+
iterator_.store_with_byte_offset(frag, byte_offset);
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
/// Store a fragment to memory
|
| 806 |
+
CUTLASS_DEVICE
|
| 807 |
+
void store(Fragment const &frag) {
|
| 808 |
+
store_with_pointer_offset(frag, 0);
|
| 809 |
+
}
|
| 810 |
+
};
|
| 811 |
+
|
| 812 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 813 |
+
|
| 814 |
+
} // namespace threadblock
|
| 815 |
+
} // namespace transform
|
| 816 |
+
} // namespace cutlass
|
| 817 |
+
|
| 818 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Templates implementing computing the addresses of loading small
|
| 34 |
+
vectors from the global memory.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/coord.h"
|
| 42 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 43 |
+
#include "cutlass/layout/matrix.h"
|
| 44 |
+
#include "cutlass/layout/tensor.h"
|
| 45 |
+
#include "cutlass/matrix_coord.h"
|
| 46 |
+
#include "cutlass/matrix_shape.h"
|
| 47 |
+
#include "cutlass/tensor_ref.h"
|
| 48 |
+
|
| 49 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 50 |
+
|
| 51 |
+
namespace cutlass {
|
| 52 |
+
namespace transform {
|
| 53 |
+
namespace threadblock {
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
/// PredicatedVectorAccessIterator
|
| 58 |
+
///
|
| 59 |
+
template <
|
| 60 |
+
/// Shape of the vector accessed by the entire threadblock
|
| 61 |
+
typename Shape,
|
| 62 |
+
/// Shape of the vector accessed by the warp
|
| 63 |
+
typename WarpShape,
|
| 64 |
+
/// Type of Element
|
| 65 |
+
typename Element,
|
| 66 |
+
/// Layout of the vector
|
| 67 |
+
typename Layout,
|
| 68 |
+
/// Number of elements for each access
|
| 69 |
+
int ElementsPerAccess,
|
| 70 |
+
/// Support residual tile
|
| 71 |
+
bool EnableResidualAccess = false
|
| 72 |
+
>
|
| 73 |
+
class PredicatedVectorAccessIterator;
|
| 74 |
+
|
| 75 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 76 |
+
|
| 77 |
+
/// Vector access iterator specialized for vectors, e.g. scale and bias
|
| 78 |
+
/// Thread arrangements are for TensorOps
|
| 79 |
+
///
|
| 80 |
+
template <
|
| 81 |
+
typename Shape_,
|
| 82 |
+
typename WarpShape_,
|
| 83 |
+
typename Element_,
|
| 84 |
+
int ElementsPerAccess,
|
| 85 |
+
bool EnableResidualAccess
|
| 86 |
+
>
|
| 87 |
+
class PredicatedVectorAccessIterator <
|
| 88 |
+
Shape_,
|
| 89 |
+
WarpShape_,
|
| 90 |
+
Element_,
|
| 91 |
+
layout::PitchLinear,
|
| 92 |
+
ElementsPerAccess,
|
| 93 |
+
EnableResidualAccess
|
| 94 |
+
> {
|
| 95 |
+
public:
|
| 96 |
+
|
| 97 |
+
using Shape = Shape_;
|
| 98 |
+
using WarpShape = WarpShape_;
|
| 99 |
+
using Element = Element_;
|
| 100 |
+
using Layout = layout::PitchLinear;
|
| 101 |
+
|
| 102 |
+
using Index = typename Layout::Index;
|
| 103 |
+
using LongIndex = typename Layout::LongIndex;
|
| 104 |
+
|
| 105 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 106 |
+
using TensorView = TensorView<Element, Layout>;
|
| 107 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 108 |
+
|
| 109 |
+
using ConstPointer = const Element *;
|
| 110 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 111 |
+
|
| 112 |
+
// static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
| 113 |
+
static int const kElementsPerAccess = ElementsPerAccess;
|
| 114 |
+
static int const kThreads = 32;
|
| 115 |
+
static int const kRowsPerIteration = 8;
|
| 116 |
+
static int const kThreadsPerRow = kThreads / kRowsPerIteration;
|
| 117 |
+
static int const kThreadsPerRowMask = 0x3;
|
| 118 |
+
static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess);
|
| 119 |
+
static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided;
|
| 120 |
+
|
| 121 |
+
using AccessType = AlignedArray<Element, kElementsPerAccess>;
|
| 122 |
+
|
| 123 |
+
private:
|
| 124 |
+
/// Internal pointer type permits fast address arithmetic
|
| 125 |
+
using BytePointer = char *;
|
| 126 |
+
|
| 127 |
+
private:
|
| 128 |
+
//
|
| 129 |
+
// Data members
|
| 130 |
+
//
|
| 131 |
+
|
| 132 |
+
/// Internal pointer to first access of tile
|
| 133 |
+
BytePointer pointer_;
|
| 134 |
+
|
| 135 |
+
/// Extent of tensor
|
| 136 |
+
TensorCoord extent_;
|
| 137 |
+
|
| 138 |
+
/// pointer offset of each thread
|
| 139 |
+
TensorCoord thread_offset_;
|
| 140 |
+
|
| 141 |
+
/// iteration index
|
| 142 |
+
LongIndex iteration_;
|
| 143 |
+
|
| 144 |
+
/// residual access
|
| 145 |
+
bool is_residual_;
|
| 146 |
+
|
| 147 |
+
/// residual offset of each thread
|
| 148 |
+
TensorCoord residual_offset_;
|
| 149 |
+
|
| 150 |
+
public:
|
| 151 |
+
/// Constructs a vector access iterator
|
| 152 |
+
CUTLASS_HOST_DEVICE
|
| 153 |
+
PredicatedVectorAccessIterator(
|
| 154 |
+
/// Pointer to the start of the vector
|
| 155 |
+
ConstPointer pointer,
|
| 156 |
+
/// Extent of vector
|
| 157 |
+
TensorCoord extent,
|
| 158 |
+
/// ID of each participating thread
|
| 159 |
+
int thread_id,
|
| 160 |
+
/// ID of each participating warp
|
| 161 |
+
int warp_id,
|
| 162 |
+
/// Initial offset of threadblock
|
| 163 |
+
TensorCoord const &threadblock_offset)
|
| 164 |
+
: pointer_(reinterpret_cast<BytePointer>(
|
| 165 |
+
const_cast<NonConstPointer>(pointer))),
|
| 166 |
+
extent_(extent),
|
| 167 |
+
is_residual_(false) {
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous;
|
| 171 |
+
|
| 172 |
+
// Per-thread offset in logical coordinates of tensor
|
| 173 |
+
|
| 174 |
+
thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) +
|
| 175 |
+
TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0);
|
| 176 |
+
|
| 177 |
+
set_iteration_index(0);
|
| 178 |
+
|
| 179 |
+
if(EnableResidualAccess) {
|
| 180 |
+
// compute residual offset
|
| 181 |
+
typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous;
|
| 182 |
+
if (residual_size) {
|
| 183 |
+
is_residual_ = true;
|
| 184 |
+
residual_offset_ = make_Coord(residual_size, 0);
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
/// Construct a PredicatedVectorAccessIterator with zero threadblock offset
|
| 190 |
+
CUTLASS_HOST_DEVICE
|
| 191 |
+
PredicatedVectorAccessIterator(
|
| 192 |
+
/// Pointer to start of vector
|
| 193 |
+
ConstPointer pointer,
|
| 194 |
+
/// Extent of vector
|
| 195 |
+
TensorCoord extent,
|
| 196 |
+
///< ID of each participating thread
|
| 197 |
+
int thread_id,
|
| 198 |
+
/// ID of each participating warp
|
| 199 |
+
int warp_id)
|
| 200 |
+
: PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id,
|
| 201 |
+
make_Coord(0, 0)) {}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
/// Overrides the internal iteration index
|
| 205 |
+
CUTLASS_HOST_DEVICE
|
| 206 |
+
void set_iteration_index(int index) {
|
| 207 |
+
iteration_ = index;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole tiles
|
| 211 |
+
CUTLASS_DEVICE
|
| 212 |
+
void add_tile_offset(
|
| 213 |
+
TensorCoord const &tile_offset) {
|
| 214 |
+
|
| 215 |
+
thread_offset_ =
|
| 216 |
+
thread_offset_ +
|
| 217 |
+
TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0);
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
/// Returns a pointer
|
| 221 |
+
CUTLASS_HOST_DEVICE
|
| 222 |
+
AccessType *get() const {
|
| 223 |
+
|
| 224 |
+
return reinterpret_cast<AccessType *>(
|
| 225 |
+
pointer_ +
|
| 226 |
+
((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess)
|
| 227 |
+
* sizeof_bits<Element>::value / 8));
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
/// Increment and return an instance to self.
|
| 231 |
+
CUTLASS_HOST_DEVICE
|
| 232 |
+
PredicatedVectorAccessIterator &operator++() {
|
| 233 |
+
++iteration_;
|
| 234 |
+
if(iteration_ >= kIterations)
|
| 235 |
+
iteration_ = 0;
|
| 236 |
+
|
| 237 |
+
return *this;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
/// Increment and return an instance to self.
|
| 241 |
+
CUTLASS_HOST_DEVICE
|
| 242 |
+
void advance() {
|
| 243 |
+
if(EnableResidualAccess && is_residual_) {
|
| 244 |
+
is_residual_ = false;
|
| 245 |
+
thread_offset_ += residual_offset_;
|
| 246 |
+
}
|
| 247 |
+
else
|
| 248 |
+
add_tile_offset(TensorCoord(1, 0));
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
/// Increment and return an instance to self.
|
| 252 |
+
CUTLASS_HOST_DEVICE
|
| 253 |
+
PredicatedVectorAccessIterator operator++(int) {
|
| 254 |
+
PredicatedVectorAccessIterator self(*this);
|
| 255 |
+
operator++();
|
| 256 |
+
return self;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
/// Returns whether access is valid or not
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
bool valid() {
|
| 262 |
+
return ((thread_offset_.contiguous() +
|
| 263 |
+
iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous());
|
| 264 |
+
}
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 268 |
+
|
| 269 |
+
/// Specialization of PredicatedVectorAccessIterator for row-major data.
|
| 270 |
+
///
|
| 271 |
+
template <
|
| 272 |
+
typename Shape_,
|
| 273 |
+
typename WarpShape_,
|
| 274 |
+
typename Element_,
|
| 275 |
+
int ElementsPerAccess,
|
| 276 |
+
bool EnableResidualAccess
|
| 277 |
+
>
|
| 278 |
+
class PredicatedVectorAccessIterator<
|
| 279 |
+
Shape_,
|
| 280 |
+
WarpShape_,
|
| 281 |
+
Element_,
|
| 282 |
+
layout::RowMajor,
|
| 283 |
+
ElementsPerAccess,
|
| 284 |
+
EnableResidualAccess
|
| 285 |
+
> {
|
| 286 |
+
public:
|
| 287 |
+
|
| 288 |
+
using Shape = Shape_;
|
| 289 |
+
using WarpShape = WarpShape_;
|
| 290 |
+
using Element = Element_;
|
| 291 |
+
using Layout = layout::RowMajor;
|
| 292 |
+
|
| 293 |
+
using Index = typename Layout::Index;
|
| 294 |
+
using LongIndex = typename Layout::LongIndex;
|
| 295 |
+
|
| 296 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 297 |
+
using TensorView = TensorView<Element, Layout>;
|
| 298 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 299 |
+
|
| 300 |
+
using ConstPointer = const Element *;
|
| 301 |
+
using NonConstPointer = typename platform::remove_const<Element>::type *;
|
| 302 |
+
|
| 303 |
+
using UnderlyingIterator = PredicatedVectorAccessIterator<
|
| 304 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
|
| 305 |
+
layout::PitchLinearShape<WarpShape::kColumn, WarpShape::kRow>,
|
| 306 |
+
Element,
|
| 307 |
+
layout::PitchLinear,
|
| 308 |
+
ElementsPerAccess,
|
| 309 |
+
EnableResidualAccess>;
|
| 310 |
+
|
| 311 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 312 |
+
static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess;
|
| 313 |
+
static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration;
|
| 314 |
+
static int const kThreads = UnderlyingIterator::kThreads;
|
| 315 |
+
static int const kIterations = UnderlyingIterator::kIterations;
|
| 316 |
+
|
| 317 |
+
private:
|
| 318 |
+
//
|
| 319 |
+
// Data members
|
| 320 |
+
//
|
| 321 |
+
|
| 322 |
+
/// Underlying pitch-linear tile iterator
|
| 323 |
+
UnderlyingIterator iterator_;
|
| 324 |
+
|
| 325 |
+
public:
|
| 326 |
+
/// Constructs a TileIterator from its precomputed state, threadblock offset,
|
| 327 |
+
/// and thread ID
|
| 328 |
+
CUTLASS_HOST_DEVICE
|
| 329 |
+
PredicatedVectorAccessIterator(
|
| 330 |
+
///< Pointer to the start of the vector
|
| 331 |
+
ConstPointer pointer,
|
| 332 |
+
///< Extent of tensor
|
| 333 |
+
TensorCoord extent,
|
| 334 |
+
///< ID of each participating thread
|
| 335 |
+
int thread_id,
|
| 336 |
+
///< ID of each participating warp
|
| 337 |
+
int warp_id,
|
| 338 |
+
///< Initial offset of threadblock
|
| 339 |
+
TensorCoord const &threadblock_offset)
|
| 340 |
+
: iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()),
|
| 341 |
+
thread_id, warp_id,
|
| 342 |
+
layout::PitchLinearCoord(threadblock_offset.column(),
|
| 343 |
+
threadblock_offset.row())) {}
|
| 344 |
+
|
| 345 |
+
/// Construct a PredicatedVectorAccessIterator with zero threadblock offset
|
| 346 |
+
CUTLASS_HOST_DEVICE
|
| 347 |
+
PredicatedVectorAccessIterator(
|
| 348 |
+
ConstPointer pointer, ///< Pointer to the start of the vector
|
| 349 |
+
TensorCoord extent, ///< Extent of tensor
|
| 350 |
+
int thread_id, ///< ID of each participating thread
|
| 351 |
+
int warp_id ///< ID of each participating warp
|
| 352 |
+
)
|
| 353 |
+
: PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id,
|
| 354 |
+
make_Coord(0, 0)) {}
|
| 355 |
+
|
| 356 |
+
/// Overrides the internal iteration index
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 359 |
+
|
| 360 |
+
/// Advances an iterator along logical dimensions of matrix in units of whole
|
| 361 |
+
/// tiles
|
| 362 |
+
CUTLASS_HOST_DEVICE
|
| 363 |
+
void add_tile_offset(TensorCoord const &tile_offset) {
|
| 364 |
+
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
/// Returns a pointer
|
| 368 |
+
CUTLASS_HOST_DEVICE
|
| 369 |
+
AccessType *get() const {
|
| 370 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
/// Advances to the next tile in memory.
|
| 374 |
+
///
|
| 375 |
+
/// The first time this method is called, predicates are updated, and the
|
| 376 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 377 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 378 |
+
/// pointer.
|
| 379 |
+
CUTLASS_HOST_DEVICE
|
| 380 |
+
PredicatedVectorAccessIterator &operator++() {
|
| 381 |
+
++iterator_;
|
| 382 |
+
return *this;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
/// Advances to the next tile in memory.
|
| 386 |
+
///
|
| 387 |
+
/// The first time this method is called, predicates are updated, and the
|
| 388 |
+
/// iterator's internal pointer is reverted to the first "steady state" tile.
|
| 389 |
+
/// Subsequent calls are lightweight and must only update the internal
|
| 390 |
+
/// pointer.
|
| 391 |
+
CUTLASS_HOST_DEVICE
|
| 392 |
+
PredicatedVectorAccessIterator operator++(int) {
|
| 393 |
+
PredicatedVectorAccessIterator self(*this);
|
| 394 |
+
operator++();
|
| 395 |
+
return self;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
/// Increment and return an instance to self.
|
| 399 |
+
CUTLASS_HOST_DEVICE
|
| 400 |
+
void advance() {
|
| 401 |
+
iterator_.advance();
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
/// Returns whether access is valid or not
|
| 405 |
+
CUTLASS_HOST_DEVICE
|
| 406 |
+
bool valid() {
|
| 407 |
+
return iterator_.valid();
|
| 408 |
+
}
|
| 409 |
+
};
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 413 |
+
|
| 414 |
+
} // namespace threadblock
|
| 415 |
+
} // namespace transform
|
| 416 |
+
} // namespace cutlass
|
| 417 |
+
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Templates implementing computing the addresses of storing of small
|
| 34 |
+
scale and bias vectors in the shared memory.
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#pragma once
|
| 38 |
+
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/array.h"
|
| 41 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 42 |
+
#include "cutlass/layout/matrix.h"
|
| 43 |
+
#include "cutlass/matrix_coord.h"
|
| 44 |
+
#include "cutlass/matrix_shape.h"
|
| 45 |
+
#include "cutlass/tensor_ref.h"
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace transform {
|
| 51 |
+
namespace threadblock {
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
/// RegularScaleBiasVectorAccessIterator
|
| 56 |
+
///
|
| 57 |
+
template <typename Shape, typename Element, typename Layout>
|
| 58 |
+
class RegularScaleBiasVectorAccessIterator;
|
| 59 |
+
|
| 60 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 61 |
+
|
| 62 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 63 |
+
///
|
| 64 |
+
///
|
| 65 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 66 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 67 |
+
/// WriteableContiguousTileIteratorConcept
|
| 68 |
+
///
|
| 69 |
+
template <typename Shape_, typename Element_>
|
| 70 |
+
class RegularScaleBiasVectorAccessIterator<Shape_, Element_, layout::PitchLinear> {
|
| 71 |
+
public:
|
| 72 |
+
|
| 73 |
+
using Shape = Shape_;
|
| 74 |
+
using Element = Element_;
|
| 75 |
+
using Layout = layout::PitchLinear;
|
| 76 |
+
|
| 77 |
+
using Index = typename Layout::Index;
|
| 78 |
+
using LongIndex = typename Layout::LongIndex;
|
| 79 |
+
|
| 80 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 81 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 82 |
+
|
| 83 |
+
/// Element type per access
|
| 84 |
+
static int const kElementsPerAccess = 128 / sizeof_bits<Element>::value;
|
| 85 |
+
static int const kThreads = Shape::kContiguous / kElementsPerAccess;
|
| 86 |
+
using AccessType = Array<Element, kElementsPerAccess>;
|
| 87 |
+
|
| 88 |
+
private:
|
| 89 |
+
//
|
| 90 |
+
// Data members
|
| 91 |
+
//
|
| 92 |
+
|
| 93 |
+
/// Internal pointer
|
| 94 |
+
AccessType *pointer_;
|
| 95 |
+
|
| 96 |
+
/// Internal byte offset
|
| 97 |
+
Index byte_offset_;
|
| 98 |
+
|
| 99 |
+
public:
|
| 100 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 101 |
+
CUTLASS_HOST_DEVICE
|
| 102 |
+
RegularScaleBiasVectorAccessIterator(
|
| 103 |
+
TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
|
| 104 |
+
///< vector
|
| 105 |
+
int thread_id ///< ID of each participating thread
|
| 106 |
+
)
|
| 107 |
+
: byte_offset_(0) {
|
| 108 |
+
// Per-thread offset in logical coordinates of tensor
|
| 109 |
+
int thread_offset = thread_id * kElementsPerAccess;
|
| 110 |
+
|
| 111 |
+
// initialize pointer
|
| 112 |
+
pointer_ =
|
| 113 |
+
reinterpret_cast<AccessType *>(scale_bias_ref.data() + thread_offset);
|
| 114 |
+
|
| 115 |
+
set_iteration_index(0);
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
/// Overrides the internal iteration index
|
| 119 |
+
CUTLASS_HOST_DEVICE
|
| 120 |
+
void set_iteration_index(int index) {}
|
| 121 |
+
|
| 122 |
+
/// Adds a pointer offset in units of Element
|
| 123 |
+
CUTLASS_HOST_DEVICE
|
| 124 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 125 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
/// Returns a pointer
|
| 129 |
+
CUTLASS_DEVICE
|
| 130 |
+
AccessType *get() const {
|
| 131 |
+
|
| 132 |
+
char *access_byte_ptr =
|
| 133 |
+
reinterpret_cast<char *>(pointer_);
|
| 134 |
+
|
| 135 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
/// Advances to the next tile in memory.
|
| 139 |
+
CUTLASS_HOST_DEVICE
|
| 140 |
+
RegularScaleBiasVectorAccessIterator &operator++() { return *this; }
|
| 141 |
+
|
| 142 |
+
/// Advances to the next tile in memory.
|
| 143 |
+
CUTLASS_HOST_DEVICE
|
| 144 |
+
RegularScaleBiasVectorAccessIterator operator++(int) {
|
| 145 |
+
RegularScaleBiasVectorAccessIterator prev(*this);
|
| 146 |
+
this->operator++();
|
| 147 |
+
|
| 148 |
+
return prev;
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
/// Adds a tile offset in the unit of tile.
|
| 152 |
+
CUTLASS_DEVICE
|
| 153 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 154 |
+
// Multiply by 2 because we store scale and bias belong to the same stage
|
| 155 |
+
// next to each other.
|
| 156 |
+
add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2);
|
| 157 |
+
}
|
| 158 |
+
};
|
| 159 |
+
|
| 160 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 161 |
+
|
| 162 |
+
/// Tile iterator specialized for row major layouts
|
| 163 |
+
///
|
| 164 |
+
///
|
| 165 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 166 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 167 |
+
/// WriteableContiguousTileIteratorConcept
|
| 168 |
+
///
|
| 169 |
+
template <typename Shape_, typename Element_>
|
| 170 |
+
class RegularScaleBiasVectorAccessIterator<
|
| 171 |
+
Shape_, Element_,
|
| 172 |
+
layout::RowMajor> {
|
| 173 |
+
public:
|
| 174 |
+
|
| 175 |
+
using Shape = Shape_;
|
| 176 |
+
using Element = Element_;
|
| 177 |
+
using Layout = layout::RowMajor;
|
| 178 |
+
|
| 179 |
+
using Index = typename Layout::Index;
|
| 180 |
+
using LongIndex = typename Layout::LongIndex;
|
| 181 |
+
|
| 182 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 183 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 184 |
+
|
| 185 |
+
/// Underlying iterator type
|
| 186 |
+
using UnderlyingIterator = RegularScaleBiasVectorAccessIterator<
|
| 187 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 188 |
+
layout::PitchLinear>;
|
| 189 |
+
|
| 190 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 191 |
+
|
| 192 |
+
private:
|
| 193 |
+
|
| 194 |
+
/// Underlying iterator
|
| 195 |
+
UnderlyingIterator iterator_;
|
| 196 |
+
|
| 197 |
+
public:
|
| 198 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 199 |
+
CUTLASS_HOST_DEVICE
|
| 200 |
+
RegularScaleBiasVectorAccessIterator(
|
| 201 |
+
TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias
|
| 202 |
+
///< vector
|
| 203 |
+
int thread_id ///< ID of each participating thread
|
| 204 |
+
)
|
| 205 |
+
: iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) {
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
/// Overrides the internal iteration index
|
| 209 |
+
CUTLASS_HOST_DEVICE
|
| 210 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 211 |
+
|
| 212 |
+
/// Adds a pointer offset in units of Element
|
| 213 |
+
CUTLASS_HOST_DEVICE
|
| 214 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 215 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
/// Returns a pointer
|
| 219 |
+
CUTLASS_HOST_DEVICE
|
| 220 |
+
AccessType *get() const {
|
| 221 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
/// Adds a tile offset
|
| 225 |
+
CUTLASS_DEVICE
|
| 226 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 227 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
/// Advances to the next tile in memory.
|
| 231 |
+
CUTLASS_HOST_DEVICE
|
| 232 |
+
RegularScaleBiasVectorAccessIterator &operator++() {
|
| 233 |
+
++iterator_;
|
| 234 |
+
return *this;
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
/// Advances to the next tile in memory.
|
| 238 |
+
CUTLASS_HOST_DEVICE
|
| 239 |
+
RegularScaleBiasVectorAccessIterator operator++(int) {
|
| 240 |
+
RegularScaleBiasVectorAccessIterator prev(*this);
|
| 241 |
+
++iterator_;
|
| 242 |
+
|
| 243 |
+
return prev;
|
| 244 |
+
}
|
| 245 |
+
};
|
| 246 |
+
|
| 247 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 248 |
+
|
| 249 |
+
} // namespace threadblock
|
| 250 |
+
} // namespace transform
|
| 251 |
+
} // namespace cutlass
|
| 252 |
+
|
| 253 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing the address computation of storing of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
namespace cutlass {
|
| 43 |
+
namespace transform {
|
| 44 |
+
namespace threadblock {
|
| 45 |
+
|
| 46 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 47 |
+
|
| 48 |
+
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
|
| 49 |
+
typename ThreadMap,
|
| 50 |
+
int Alignment =
|
| 51 |
+
sizeof_bits<Element>::value* ThreadMap::kElementsPerAccess / 8>
|
| 52 |
+
class RegularTileAccessIterator;
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
} // namespace threadblock
|
| 57 |
+
} // namespace transform
|
| 58 |
+
} // namespace cutlass
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing computing the addresses of storing of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
#include "cutlass/matrix_coord.h"
|
| 43 |
+
#include "cutlass/matrix_shape.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace transform {
|
| 52 |
+
namespace threadblock {
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 57 |
+
///
|
| 58 |
+
///
|
| 59 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 60 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 61 |
+
/// WriteableContiguousTileIteratorConcept
|
| 62 |
+
///
|
| 63 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 64 |
+
typename ThreadMap_, int Alignment>
|
| 65 |
+
class RegularTileAccessIterator<
|
| 66 |
+
Shape_, Element_,
|
| 67 |
+
layout::PitchLinear,
|
| 68 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 69 |
+
public:
|
| 70 |
+
static_assert(
|
| 71 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 72 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 73 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 74 |
+
|
| 75 |
+
using Shape = Shape_;
|
| 76 |
+
using Element = Element_;
|
| 77 |
+
using Layout = layout::PitchLinear;
|
| 78 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 79 |
+
static int const kAlignment = Alignment;
|
| 80 |
+
|
| 81 |
+
using Index = typename Layout::Index;
|
| 82 |
+
using LongIndex = typename Layout::LongIndex;
|
| 83 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 84 |
+
|
| 85 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 86 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 87 |
+
|
| 88 |
+
using ThreadMap = ThreadMap_;
|
| 89 |
+
|
| 90 |
+
/// Element type per access
|
| 91 |
+
using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
|
| 92 |
+
|
| 93 |
+
private:
|
| 94 |
+
//
|
| 95 |
+
// Data members
|
| 96 |
+
//
|
| 97 |
+
|
| 98 |
+
/// Stride value
|
| 99 |
+
StrideIndex stride_;
|
| 100 |
+
|
| 101 |
+
/// Internal pointer to first access of tile
|
| 102 |
+
AccessType *pointer_;
|
| 103 |
+
|
| 104 |
+
/// Internal byte offset
|
| 105 |
+
Index byte_offset_;
|
| 106 |
+
|
| 107 |
+
/// Iteration in the contiguous dimension
|
| 108 |
+
int iteration_contiguous_;
|
| 109 |
+
|
| 110 |
+
/// Iteration in the strided dimension
|
| 111 |
+
int iteration_strided_;
|
| 112 |
+
|
| 113 |
+
public:
|
| 114 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 115 |
+
CUTLASS_HOST_DEVICE
|
| 116 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 117 |
+
int thread_id ///< ID of each participating thread
|
| 118 |
+
)
|
| 119 |
+
: stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
|
| 120 |
+
byte_offset_(0) {
|
| 121 |
+
|
| 122 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 123 |
+
|
| 124 |
+
// initialize pointer
|
| 125 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
|
| 126 |
+
|
| 127 |
+
set_iteration_index(0);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/// Overrides the internal iteration index
|
| 131 |
+
CUTLASS_HOST_DEVICE
|
| 132 |
+
void set_iteration_index(int index) {
|
| 133 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 134 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
/// Adds a pointer offset in units of Element
|
| 138 |
+
CUTLASS_HOST_DEVICE
|
| 139 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 140 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
/// Returns a pointer
|
| 144 |
+
CUTLASS_DEVICE
|
| 145 |
+
AccessType *get() const {
|
| 146 |
+
|
| 147 |
+
AccessType *access_ptr = pointer_;
|
| 148 |
+
|
| 149 |
+
int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
|
| 150 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 151 |
+
ThreadMap::kElementsPerAccess;
|
| 152 |
+
|
| 153 |
+
char *access_byte_ptr =
|
| 154 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 155 |
+
|
| 156 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
/// Advances to the next tile in memory.
|
| 160 |
+
CUTLASS_HOST_DEVICE
|
| 161 |
+
RegularTileAccessIterator &operator++() {
|
| 162 |
+
++iteration_contiguous_;
|
| 163 |
+
|
| 164 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 165 |
+
return *this;
|
| 166 |
+
|
| 167 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 168 |
+
// ThreadMap::Iteration::kContiguous)
|
| 169 |
+
iteration_contiguous_ = 0;
|
| 170 |
+
++iteration_strided_;
|
| 171 |
+
|
| 172 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 173 |
+
return *this;
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 177 |
+
// which means we enter the next tile.
|
| 178 |
+
iteration_strided_ = 0;
|
| 179 |
+
|
| 180 |
+
return *this;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
/// Advances to the next tile in memory.
|
| 184 |
+
CUTLASS_HOST_DEVICE
|
| 185 |
+
RegularTileAccessIterator operator++(int) {
|
| 186 |
+
RegularTileAccessIterator prev(*this);
|
| 187 |
+
this->operator++();
|
| 188 |
+
|
| 189 |
+
return prev;
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
/// Adds a tile offset in the unit of tile.
|
| 193 |
+
/// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory.
|
| 194 |
+
/// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B.
|
| 195 |
+
/// For row major A operand, k dimension is contiguous dimension;
|
| 196 |
+
/// For col major A operand, k dimension is strided dimension;
|
| 197 |
+
/// For row major B operand, k dimension is strided dimension;
|
| 198 |
+
/// For col major B operand, k dimension is contiguous dimension.
|
| 199 |
+
/// Below two classes map col/row major to the pitch linear coordinates used
|
| 200 |
+
/// in this base class.
|
| 201 |
+
CUTLASS_DEVICE
|
| 202 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 203 |
+
add_pointer_offset(coord.contiguous() * Shape::kContiguous +
|
| 204 |
+
coord.strided() * Shape::kStrided * stride_ *
|
| 205 |
+
ThreadMap::kElementsPerAccess);
|
| 206 |
+
}
|
| 207 |
+
};
|
| 208 |
+
|
| 209 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 210 |
+
|
| 211 |
+
/// Tile iterator specialized for column major layouts
|
| 212 |
+
///
|
| 213 |
+
///
|
| 214 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 215 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 216 |
+
/// WriteableContiguousTileIteratorConcept
|
| 217 |
+
///
|
| 218 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 219 |
+
typename ThreadMap_, int Alignment>
|
| 220 |
+
class RegularTileAccessIterator<
|
| 221 |
+
Shape_, Element_,
|
| 222 |
+
layout::ColumnMajor,
|
| 223 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 224 |
+
public:
|
| 225 |
+
static_assert(
|
| 226 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 227 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 228 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 229 |
+
|
| 230 |
+
using Shape = Shape_;
|
| 231 |
+
using Element = Element_;
|
| 232 |
+
using Layout = layout::ColumnMajor;
|
| 233 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 234 |
+
static int const kAlignment = Alignment;
|
| 235 |
+
|
| 236 |
+
using Index = typename Layout::Index;
|
| 237 |
+
using LongIndex = typename Layout::LongIndex;
|
| 238 |
+
|
| 239 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 240 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 241 |
+
|
| 242 |
+
using ThreadMap = ThreadMap_;
|
| 243 |
+
|
| 244 |
+
/// Underlying iterator type
|
| 245 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 246 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 247 |
+
layout::PitchLinear,
|
| 248 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 249 |
+
ThreadMap_>;
|
| 250 |
+
|
| 251 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 252 |
+
|
| 253 |
+
private:
|
| 254 |
+
|
| 255 |
+
/// Underlying iterator
|
| 256 |
+
UnderlyingIterator iterator_;
|
| 257 |
+
|
| 258 |
+
public:
|
| 259 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 260 |
+
CUTLASS_HOST_DEVICE
|
| 261 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 262 |
+
int thread_id ///< ID of each participating thread
|
| 263 |
+
)
|
| 264 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 265 |
+
|
| 266 |
+
/// Overrides the internal iteration index
|
| 267 |
+
CUTLASS_HOST_DEVICE
|
| 268 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 269 |
+
|
| 270 |
+
/// Adds a pointer offset in units of Element
|
| 271 |
+
CUTLASS_HOST_DEVICE
|
| 272 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 273 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
/// Returns a pointer
|
| 277 |
+
CUTLASS_HOST_DEVICE
|
| 278 |
+
AccessType *get() const {
|
| 279 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
/// Adds a tile offset
|
| 283 |
+
CUTLASS_DEVICE
|
| 284 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 285 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
/// Advances to the next tile in memory.
|
| 289 |
+
CUTLASS_HOST_DEVICE
|
| 290 |
+
RegularTileAccessIterator &operator++() {
|
| 291 |
+
++iterator_;
|
| 292 |
+
return *this;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
/// Advances to the next tile in memory.
|
| 296 |
+
CUTLASS_HOST_DEVICE
|
| 297 |
+
RegularTileAccessIterator operator++(int) {
|
| 298 |
+
RegularTileAccessIterator prev(*this);
|
| 299 |
+
++iterator_;
|
| 300 |
+
|
| 301 |
+
return prev;
|
| 302 |
+
}
|
| 303 |
+
};
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 307 |
+
|
| 308 |
+
/// Tile iterator specialized for row major layouts
|
| 309 |
+
///
|
| 310 |
+
///
|
| 311 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 312 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 313 |
+
/// WriteableContiguousTileIteratorConcept
|
| 314 |
+
///
|
| 315 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 316 |
+
typename ThreadMap_, int Alignment>
|
| 317 |
+
class RegularTileAccessIterator<
|
| 318 |
+
Shape_, Element_,
|
| 319 |
+
layout::RowMajor,
|
| 320 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 321 |
+
public:
|
| 322 |
+
static_assert(
|
| 323 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 324 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 325 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 326 |
+
|
| 327 |
+
using Shape = Shape_;
|
| 328 |
+
using Element = Element_;
|
| 329 |
+
using Layout = layout::RowMajor;
|
| 330 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 331 |
+
static int const kAlignment = Alignment;
|
| 332 |
+
|
| 333 |
+
using Index = typename Layout::Index;
|
| 334 |
+
using LongIndex = typename Layout::LongIndex;
|
| 335 |
+
|
| 336 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 337 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 338 |
+
|
| 339 |
+
using ThreadMap = ThreadMap_;
|
| 340 |
+
|
| 341 |
+
/// Underlying iterator type
|
| 342 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 343 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 344 |
+
layout::PitchLinear,
|
| 345 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 346 |
+
ThreadMap_>;
|
| 347 |
+
|
| 348 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 349 |
+
|
| 350 |
+
private:
|
| 351 |
+
|
| 352 |
+
/// Underlying iterator
|
| 353 |
+
UnderlyingIterator iterator_;
|
| 354 |
+
|
| 355 |
+
public:
|
| 356 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 359 |
+
int thread_id ///< ID of each participating thread
|
| 360 |
+
)
|
| 361 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 362 |
+
|
| 363 |
+
/// Overrides the internal iteration index
|
| 364 |
+
CUTLASS_HOST_DEVICE
|
| 365 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 366 |
+
|
| 367 |
+
/// Adds a pointer offset in units of Element
|
| 368 |
+
CUTLASS_HOST_DEVICE
|
| 369 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 370 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
/// Returns a pointer
|
| 374 |
+
CUTLASS_HOST_DEVICE
|
| 375 |
+
AccessType *get() const {
|
| 376 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
/// Adds a tile offset
|
| 380 |
+
CUTLASS_DEVICE
|
| 381 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 382 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
/// Advances to the next tile in memory.
|
| 386 |
+
CUTLASS_HOST_DEVICE
|
| 387 |
+
RegularTileAccessIterator &operator++() {
|
| 388 |
+
++iterator_;
|
| 389 |
+
return *this;
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
/// Advances to the next tile in memory.
|
| 393 |
+
CUTLASS_HOST_DEVICE
|
| 394 |
+
RegularTileAccessIterator operator++(int) {
|
| 395 |
+
RegularTileAccessIterator prev(*this);
|
| 396 |
+
++iterator_;
|
| 397 |
+
|
| 398 |
+
return prev;
|
| 399 |
+
}
|
| 400 |
+
};
|
| 401 |
+
|
| 402 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 403 |
+
|
| 404 |
+
} // namespace threadblock
|
| 405 |
+
} // namespace transform
|
| 406 |
+
} // namespace cutlass
|
| 407 |
+
|
| 408 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h
ADDED
|
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing computing the addresses of storing of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/array.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
#include "cutlass/layout/matrix.h"
|
| 42 |
+
#include "cutlass/matrix_coord.h"
|
| 43 |
+
#include "cutlass/matrix_shape.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace transform {
|
| 52 |
+
namespace threadblock {
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 56 |
+
|
| 57 |
+
template <typename Shape, typename Element, typename Layout, int AdvanceRank,
|
| 58 |
+
typename ThreadMap,
|
| 59 |
+
bool Dynamic_iterations = false,
|
| 60 |
+
int Alignment =
|
| 61 |
+
sizeof_bits<Element>::value* ThreadMap::kElementsPerAccess / 8
|
| 62 |
+
>
|
| 63 |
+
class RegularTileAccessIteratorDirectConv;
|
| 64 |
+
|
| 65 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 66 |
+
|
| 67 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF
|
| 68 |
+
///
|
| 69 |
+
///
|
| 70 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 71 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 72 |
+
/// WriteableContiguousTileIteratorConcept
|
| 73 |
+
///
|
| 74 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 75 |
+
typename ThreadMap_, int Alignment>
|
| 76 |
+
class RegularTileAccessIteratorDirectConv<
|
| 77 |
+
Shape_, Element_,
|
| 78 |
+
layout::PitchLinear,
|
| 79 |
+
AdvanceRank, ThreadMap_, false, Alignment> {
|
| 80 |
+
public:
|
| 81 |
+
static_assert(
|
| 82 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 83 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 84 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 85 |
+
|
| 86 |
+
using Shape = Shape_;
|
| 87 |
+
using Element = Element_;
|
| 88 |
+
using Layout = layout::PitchLinear;
|
| 89 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 90 |
+
static int const kAlignment = Alignment;
|
| 91 |
+
|
| 92 |
+
using Index = typename Layout::Index;
|
| 93 |
+
using LongIndex = typename Layout::LongIndex;
|
| 94 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 95 |
+
|
| 96 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 97 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 98 |
+
|
| 99 |
+
using ThreadMap = ThreadMap_;
|
| 100 |
+
|
| 101 |
+
/// Element type per access
|
| 102 |
+
using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
|
| 103 |
+
|
| 104 |
+
private:
|
| 105 |
+
//
|
| 106 |
+
// Data members
|
| 107 |
+
//
|
| 108 |
+
|
| 109 |
+
/// Stride value
|
| 110 |
+
StrideIndex stride_;
|
| 111 |
+
|
| 112 |
+
/// Internal pointer to first access of tile
|
| 113 |
+
AccessType *pointer_;
|
| 114 |
+
|
| 115 |
+
/// Internal byte offset
|
| 116 |
+
Index byte_offset_;
|
| 117 |
+
|
| 118 |
+
/// Iteration in the contiguous dimension
|
| 119 |
+
int iteration_contiguous_;
|
| 120 |
+
|
| 121 |
+
/// Iteration in the strided dimension
|
| 122 |
+
int iteration_strided_;
|
| 123 |
+
|
| 124 |
+
public:
|
| 125 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 126 |
+
CUTLASS_HOST_DEVICE
|
| 127 |
+
RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
|
| 128 |
+
int thread_id ///< ID of each participating thread
|
| 129 |
+
)
|
| 130 |
+
: stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
|
| 131 |
+
byte_offset_(0) {
|
| 132 |
+
|
| 133 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 134 |
+
|
| 135 |
+
// initialize pointer
|
| 136 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
|
| 137 |
+
|
| 138 |
+
set_iteration_index(0);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
/// Overrides the internal iteration index
|
| 142 |
+
CUTLASS_HOST_DEVICE
|
| 143 |
+
void set_iteration_index(int index) {
|
| 144 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 145 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
/// Overrides the internal iteration index
|
| 149 |
+
CUTLASS_HOST_DEVICE
|
| 150 |
+
void set_iteration_num(int num) {
|
| 151 |
+
//Do nothing
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// Adds a pointer offset in units of Element
|
| 155 |
+
CUTLASS_HOST_DEVICE
|
| 156 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 157 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Returns a pointer
|
| 161 |
+
CUTLASS_DEVICE
|
| 162 |
+
AccessType *get() const {
|
| 163 |
+
|
| 164 |
+
AccessType *access_ptr = pointer_;
|
| 165 |
+
|
| 166 |
+
int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
|
| 167 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 168 |
+
ThreadMap::kElementsPerAccess;
|
| 169 |
+
|
| 170 |
+
char *access_byte_ptr =
|
| 171 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 172 |
+
|
| 173 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
/// Advances to the next tile in memory.
|
| 177 |
+
CUTLASS_HOST_DEVICE
|
| 178 |
+
RegularTileAccessIteratorDirectConv &operator++() {
|
| 179 |
+
++iteration_contiguous_;
|
| 180 |
+
|
| 181 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 182 |
+
return *this;
|
| 183 |
+
|
| 184 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 185 |
+
// ThreadMap::Iteration::kContiguous)
|
| 186 |
+
iteration_contiguous_ = 0;
|
| 187 |
+
++iteration_strided_;
|
| 188 |
+
|
| 189 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 190 |
+
return *this;
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 194 |
+
// which means we enter the next tile.
|
| 195 |
+
iteration_strided_ = 0;
|
| 196 |
+
|
| 197 |
+
return *this;
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
/// Advances to the next tile in memory.
|
| 201 |
+
CUTLASS_HOST_DEVICE
|
| 202 |
+
RegularTileAccessIteratorDirectConv operator++(int) {
|
| 203 |
+
RegularTileAccessIteratorDirectConv prev(*this);
|
| 204 |
+
this->operator++();
|
| 205 |
+
|
| 206 |
+
return prev;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Adds a tile offset in the unit of tile.
|
| 210 |
+
CUTLASS_DEVICE
|
| 211 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 212 |
+
add_pointer_offset(coord.contiguous() * Shape::kContiguous +
|
| 213 |
+
coord.strided() * ThreadMap::Iterations::kStrided *
|
| 214 |
+
ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess);
|
| 215 |
+
}
|
| 216 |
+
};
|
| 217 |
+
|
| 218 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 219 |
+
|
| 220 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON
|
| 221 |
+
///
|
| 222 |
+
///
|
| 223 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 224 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 225 |
+
/// WriteableContiguousTileIteratorConcept
|
| 226 |
+
///
|
| 227 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 228 |
+
typename ThreadMap_, int Alignment>
|
| 229 |
+
class RegularTileAccessIteratorDirectConv<
|
| 230 |
+
Shape_, Element_,
|
| 231 |
+
layout::PitchLinear,
|
| 232 |
+
AdvanceRank, ThreadMap_,true, Alignment> {
|
| 233 |
+
public:
|
| 234 |
+
static_assert(
|
| 235 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 236 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 237 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 238 |
+
|
| 239 |
+
using Shape = Shape_;
|
| 240 |
+
using Element = Element_;
|
| 241 |
+
using Layout = layout::PitchLinear;
|
| 242 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 243 |
+
static int const kAlignment = Alignment;
|
| 244 |
+
|
| 245 |
+
using Index = typename Layout::Index;
|
| 246 |
+
using LongIndex = typename Layout::LongIndex;
|
| 247 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 248 |
+
|
| 249 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 250 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 251 |
+
|
| 252 |
+
using ThreadMap = ThreadMap_;
|
| 253 |
+
|
| 254 |
+
/// Element type per access
|
| 255 |
+
using AccessType = Array<Element, ThreadMap::kElementsPerAccess>;
|
| 256 |
+
|
| 257 |
+
private:
|
| 258 |
+
//
|
| 259 |
+
// Data members
|
| 260 |
+
//
|
| 261 |
+
|
| 262 |
+
/// Stride value
|
| 263 |
+
StrideIndex stride_;
|
| 264 |
+
|
| 265 |
+
/// Internal pointer to first access of tile
|
| 266 |
+
AccessType *pointer_;
|
| 267 |
+
|
| 268 |
+
/// Internal byte offset
|
| 269 |
+
Index byte_offset_;
|
| 270 |
+
|
| 271 |
+
/// Iteration in the contiguous dimension
|
| 272 |
+
int iteration_contiguous_;
|
| 273 |
+
|
| 274 |
+
/// Iteration in the strided dimension
|
| 275 |
+
int iteration_strided_;
|
| 276 |
+
|
| 277 |
+
/// Total iterattions in the strided dimension: Dynamic value
|
| 278 |
+
int total_iteration_strided_;
|
| 279 |
+
|
| 280 |
+
public:
|
| 281 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 282 |
+
CUTLASS_HOST_DEVICE
|
| 283 |
+
RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
|
| 284 |
+
int thread_id ///< ID of each participating thread
|
| 285 |
+
)
|
| 286 |
+
: stride_(ref.stride(0) / ThreadMap::kElementsPerAccess),
|
| 287 |
+
byte_offset_(0) {
|
| 288 |
+
|
| 289 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 290 |
+
|
| 291 |
+
// initialize pointer
|
| 292 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_base));
|
| 293 |
+
|
| 294 |
+
set_iteration_index(0);
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
/// Overrides the internal iteration index
|
| 298 |
+
CUTLASS_HOST_DEVICE
|
| 299 |
+
void set_iteration_index(int index) {
|
| 300 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 301 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/// Overrides the internal iteration index
|
| 305 |
+
CUTLASS_HOST_DEVICE
|
| 306 |
+
void set_iteration_num(int num) {
|
| 307 |
+
total_iteration_strided_ = num;
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
/// Adds a pointer offset in units of Element
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 313 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
/// Returns a pointer
|
| 317 |
+
CUTLASS_DEVICE
|
| 318 |
+
AccessType *get() const {
|
| 319 |
+
|
| 320 |
+
AccessType *access_ptr = pointer_;
|
| 321 |
+
|
| 322 |
+
int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
|
| 323 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 324 |
+
ThreadMap::kElementsPerAccess;
|
| 325 |
+
|
| 326 |
+
char *access_byte_ptr =
|
| 327 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 328 |
+
|
| 329 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
/// Advances to the next tile in memory.
|
| 333 |
+
CUTLASS_HOST_DEVICE
|
| 334 |
+
RegularTileAccessIteratorDirectConv &operator++() {
|
| 335 |
+
++iteration_contiguous_;
|
| 336 |
+
|
| 337 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 338 |
+
return *this;
|
| 339 |
+
|
| 340 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 341 |
+
// ThreadMap::Iteration::kContiguous)
|
| 342 |
+
iteration_contiguous_ = 0;
|
| 343 |
+
++iteration_strided_;
|
| 344 |
+
|
| 345 |
+
if (iteration_strided_ < total_iteration_strided_) {
|
| 346 |
+
return *this;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 350 |
+
// which means we enter the next tile.
|
| 351 |
+
iteration_strided_ = 0;
|
| 352 |
+
|
| 353 |
+
return *this;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
/// Advances to the next tile in memory.
|
| 357 |
+
CUTLASS_HOST_DEVICE
|
| 358 |
+
RegularTileAccessIteratorDirectConv operator++(int) {
|
| 359 |
+
RegularTileAccessIteratorDirectConv prev(*this);
|
| 360 |
+
this->operator++();
|
| 361 |
+
|
| 362 |
+
return prev;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
/// Adds a tile offset in the unit of tile.
|
| 366 |
+
CUTLASS_DEVICE
|
| 367 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 368 |
+
add_pointer_offset(coord.contiguous() * Shape::kContiguous +
|
| 369 |
+
coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ *
|
| 370 |
+
ThreadMap::kElementsPerAccess);
|
| 371 |
+
}
|
| 372 |
+
};
|
| 373 |
+
|
| 374 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 375 |
+
|
| 376 |
+
/// Tile iterator specialized for column major layouts
|
| 377 |
+
///
|
| 378 |
+
///
|
| 379 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 380 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 381 |
+
/// WriteableContiguousTileIteratorConcept
|
| 382 |
+
///
|
| 383 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 384 |
+
typename ThreadMap_,bool Dynamic_iterations, int Alignment >
|
| 385 |
+
class RegularTileAccessIteratorDirectConv<
|
| 386 |
+
Shape_, Element_,
|
| 387 |
+
layout::ColumnMajor,
|
| 388 |
+
AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> {
|
| 389 |
+
public:
|
| 390 |
+
static_assert(
|
| 391 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 392 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 393 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 394 |
+
|
| 395 |
+
using Shape = Shape_;
|
| 396 |
+
using Element = Element_;
|
| 397 |
+
using Layout = layout::ColumnMajor;
|
| 398 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 399 |
+
static int const kAlignment = Alignment;
|
| 400 |
+
|
| 401 |
+
using Index = typename Layout::Index;
|
| 402 |
+
using LongIndex = typename Layout::LongIndex;
|
| 403 |
+
|
| 404 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 405 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 406 |
+
|
| 407 |
+
using ThreadMap = ThreadMap_;
|
| 408 |
+
|
| 409 |
+
/// Underlying iterator type
|
| 410 |
+
using UnderlyingIterator = RegularTileAccessIteratorDirectConv<
|
| 411 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 412 |
+
layout::PitchLinear,
|
| 413 |
+
(kAdvanceRank == 0 ? 0 : 1),
|
| 414 |
+
ThreadMap_,
|
| 415 |
+
Dynamic_iterations>;
|
| 416 |
+
|
| 417 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 418 |
+
|
| 419 |
+
private:
|
| 420 |
+
|
| 421 |
+
/// Underlying iterator
|
| 422 |
+
UnderlyingIterator iterator_;
|
| 423 |
+
|
| 424 |
+
public:
|
| 425 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 426 |
+
CUTLASS_HOST_DEVICE
|
| 427 |
+
RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
|
| 428 |
+
int thread_id ///< ID of each participating thread
|
| 429 |
+
)
|
| 430 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 431 |
+
|
| 432 |
+
/// Overrides the internal iteration index
|
| 433 |
+
CUTLASS_HOST_DEVICE
|
| 434 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 435 |
+
|
| 436 |
+
/// Overrides the internal iteration index
|
| 437 |
+
CUTLASS_HOST_DEVICE
|
| 438 |
+
void set_iteration_num(int num) {
|
| 439 |
+
iterator_.set_iteration_num(num);
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
/// Adds a pointer offset in units of Element
|
| 443 |
+
CUTLASS_HOST_DEVICE
|
| 444 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 445 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
/// Returns a pointer
|
| 449 |
+
CUTLASS_HOST_DEVICE
|
| 450 |
+
AccessType *get() const {
|
| 451 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
/// Adds a tile offset
|
| 455 |
+
CUTLASS_DEVICE
|
| 456 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 457 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
/// Advances to the next tile in memory.
|
| 461 |
+
CUTLASS_HOST_DEVICE
|
| 462 |
+
RegularTileAccessIteratorDirectConv &operator++() {
|
| 463 |
+
++iterator_;
|
| 464 |
+
return *this;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
/// Advances to the next tile in memory.
|
| 468 |
+
CUTLASS_HOST_DEVICE
|
| 469 |
+
RegularTileAccessIteratorDirectConv operator++(int) {
|
| 470 |
+
RegularTileAccessIteratorDirectConv prev(*this);
|
| 471 |
+
++iterator_;
|
| 472 |
+
|
| 473 |
+
return prev;
|
| 474 |
+
}
|
| 475 |
+
};
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 479 |
+
|
| 480 |
+
/// Tile iterator specialized for row major layouts
|
| 481 |
+
///
|
| 482 |
+
///
|
| 483 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 484 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 485 |
+
/// WriteableContiguousTileIteratorConcept
|
| 486 |
+
///
|
| 487 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 488 |
+
typename ThreadMap_,bool Dynamic_iterations, int Alignment>
|
| 489 |
+
class RegularTileAccessIteratorDirectConv<
|
| 490 |
+
Shape_, Element_,
|
| 491 |
+
layout::RowMajor,
|
| 492 |
+
AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> {
|
| 493 |
+
public:
|
| 494 |
+
static_assert(
|
| 495 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 496 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 497 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 498 |
+
|
| 499 |
+
using Shape = Shape_;
|
| 500 |
+
using Element = Element_;
|
| 501 |
+
using Layout = layout::RowMajor;
|
| 502 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 503 |
+
static int const kAlignment = Alignment;
|
| 504 |
+
|
| 505 |
+
using Index = typename Layout::Index;
|
| 506 |
+
using LongIndex = typename Layout::LongIndex;
|
| 507 |
+
|
| 508 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 509 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 510 |
+
|
| 511 |
+
using ThreadMap = ThreadMap_;
|
| 512 |
+
|
| 513 |
+
/// Underlying iterator type
|
| 514 |
+
using UnderlyingIterator = RegularTileAccessIteratorDirectConv<
|
| 515 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 516 |
+
layout::PitchLinear,
|
| 517 |
+
(kAdvanceRank == 0 ? 1 : 0),
|
| 518 |
+
ThreadMap_,
|
| 519 |
+
Dynamic_iterations>;
|
| 520 |
+
|
| 521 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 522 |
+
|
| 523 |
+
private:
|
| 524 |
+
|
| 525 |
+
/// Underlying iterator
|
| 526 |
+
UnderlyingIterator iterator_;
|
| 527 |
+
|
| 528 |
+
public:
|
| 529 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 530 |
+
CUTLASS_HOST_DEVICE
|
| 531 |
+
RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor
|
| 532 |
+
int thread_id ///< ID of each participating thread
|
| 533 |
+
)
|
| 534 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 535 |
+
|
| 536 |
+
/// Overrides the internal iteration index
|
| 537 |
+
CUTLASS_HOST_DEVICE
|
| 538 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 539 |
+
|
| 540 |
+
/// Overrides the internal iteration index
|
| 541 |
+
CUTLASS_HOST_DEVICE
|
| 542 |
+
void set_iteration_num(int num) {
|
| 543 |
+
iterator_.set_iteration_num(num);
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
/// Adds a pointer offset in units of Element
|
| 547 |
+
CUTLASS_HOST_DEVICE
|
| 548 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 549 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 550 |
+
}
|
| 551 |
+
|
| 552 |
+
/// Returns a pointer
|
| 553 |
+
CUTLASS_HOST_DEVICE
|
| 554 |
+
AccessType *get() const {
|
| 555 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
/// Adds a tile offset
|
| 559 |
+
CUTLASS_DEVICE
|
| 560 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 561 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
/// Advances to the next tile in memory.
|
| 565 |
+
CUTLASS_HOST_DEVICE
|
| 566 |
+
RegularTileAccessIteratorDirectConv &operator++() {
|
| 567 |
+
++iterator_;
|
| 568 |
+
return *this;
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
/// Advances to the next tile in memory.
|
| 572 |
+
CUTLASS_HOST_DEVICE
|
| 573 |
+
RegularTileAccessIteratorDirectConv operator++(int) {
|
| 574 |
+
RegularTileAccessIteratorDirectConv prev(*this);
|
| 575 |
+
++iterator_;
|
| 576 |
+
|
| 577 |
+
return prev;
|
| 578 |
+
}
|
| 579 |
+
};
|
| 580 |
+
|
| 581 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 582 |
+
|
| 583 |
+
} // namespace threadblock
|
| 584 |
+
} // namespace transform
|
| 585 |
+
} // namespace cutlass
|
| 586 |
+
|
| 587 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h
ADDED
|
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing computing the addresses of storing of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
|
| 42 |
+
#include "cutlass/matrix_coord.h"
|
| 43 |
+
#include "cutlass/matrix_shape.h"
|
| 44 |
+
#include "cutlass/tensor_ref.h"
|
| 45 |
+
#include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
|
| 46 |
+
|
| 47 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 48 |
+
|
| 49 |
+
namespace cutlass {
|
| 50 |
+
namespace transform {
|
| 51 |
+
namespace threadblock {
|
| 52 |
+
|
| 53 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 56 |
+
///
|
| 57 |
+
///
|
| 58 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 59 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 60 |
+
/// WriteableContiguousTileIteratorConcept
|
| 61 |
+
///
|
| 62 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 63 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 64 |
+
class RegularTileAccessIterator<
|
| 65 |
+
Shape_, Element_,
|
| 66 |
+
layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
|
| 67 |
+
Crosswise>,
|
| 68 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 69 |
+
public:
|
| 70 |
+
static_assert(
|
| 71 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 72 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 73 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 74 |
+
|
| 75 |
+
using Shape = Shape_;
|
| 76 |
+
using Element = Element_;
|
| 77 |
+
using Layout =
|
| 78 |
+
layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
|
| 79 |
+
Crosswise>;
|
| 80 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 81 |
+
static int const kAlignment = Alignment;
|
| 82 |
+
static int const kCrosswise = Crosswise;
|
| 83 |
+
|
| 84 |
+
using Index = typename Layout::Index;
|
| 85 |
+
using LongIndex = typename Layout::LongIndex;
|
| 86 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 87 |
+
|
| 88 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 89 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 90 |
+
|
| 91 |
+
using ThreadMap = ThreadMap_;
|
| 92 |
+
|
| 93 |
+
/// Internal details made public to facilitate introspection
|
| 94 |
+
struct Detail {
|
| 95 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 96 |
+
/// length.
|
| 97 |
+
static int const kAccessSizeInBits = 128;
|
| 98 |
+
|
| 99 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 100 |
+
ThreadMap::kElementsPerAccess ==
|
| 101 |
+
kAccessSizeInBits,
|
| 102 |
+
"This iterator requires a policy whose access size is 128bs");
|
| 103 |
+
|
| 104 |
+
///< Number of pointers
|
| 105 |
+
static int const kPointerCount =
|
| 106 |
+
(ThreadMap::Iterations::kStrided > 1 ? 2 : 1);
|
| 107 |
+
};
|
| 108 |
+
|
| 109 |
+
/// Element type per access
|
| 110 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 111 |
+
|
| 112 |
+
private:
|
| 113 |
+
//
|
| 114 |
+
// Data members
|
| 115 |
+
//
|
| 116 |
+
|
| 117 |
+
/// Stride value
|
| 118 |
+
StrideIndex stride_;
|
| 119 |
+
|
| 120 |
+
/// Internal pointer to first access of tile
|
| 121 |
+
AccessType *pointer_[Detail::kPointerCount];
|
| 122 |
+
|
| 123 |
+
/// Internal byte offset
|
| 124 |
+
Index byte_offset_;
|
| 125 |
+
|
| 126 |
+
/// Iteration in the contiguous dimension
|
| 127 |
+
int iteration_contiguous_;
|
| 128 |
+
|
| 129 |
+
/// Iteration in the strided dimension
|
| 130 |
+
int iteration_strided_;
|
| 131 |
+
|
| 132 |
+
public:
|
| 133 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 136 |
+
int thread_id ///< ID of each participating thread
|
| 137 |
+
)
|
| 138 |
+
: stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess),
|
| 139 |
+
byte_offset_(0) {
|
| 140 |
+
layout::PitchLinearCoord thread_offset_base =
|
| 141 |
+
ThreadMap::initial_offset(thread_id);
|
| 142 |
+
|
| 143 |
+
CUTLASS_PRAGMA_UNROLL
|
| 144 |
+
for (int i = 0; i < Detail::kPointerCount; ++i) {
|
| 145 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 146 |
+
// pointer (units of elements)
|
| 147 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile =
|
| 148 |
+
thread_offset_base +
|
| 149 |
+
layout::PitchLinearCoord{
|
| 150 |
+
0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i};
|
| 151 |
+
|
| 152 |
+
// initialize pointer
|
| 153 |
+
pointer_[i] = reinterpret_cast<AccessType *>(
|
| 154 |
+
ref.data() + ref.offset(thread_offset_in_threadblock_tile));
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
set_iteration_index(0);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Overrides the internal iteration index
|
| 161 |
+
CUTLASS_HOST_DEVICE
|
| 162 |
+
void set_iteration_index(int index) {
|
| 163 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 164 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
/// Adds a pointer offset in units of Element
|
| 168 |
+
CUTLASS_HOST_DEVICE
|
| 169 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 170 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
/// Returns a pointer
|
| 174 |
+
CUTLASS_HOST_DEVICE
|
| 175 |
+
AccessType *get() const {
|
| 176 |
+
AccessType *access_ptr = pointer_[iteration_strided_ & 1];
|
| 177 |
+
int stride_idx = (iteration_strided_ & ~1);
|
| 178 |
+
|
| 179 |
+
int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor +
|
| 180 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 181 |
+
ThreadMap::kElementsPerAccess;
|
| 182 |
+
|
| 183 |
+
char *access_byte_ptr =
|
| 184 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 185 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/// Advances to the next tile in memory.
|
| 189 |
+
CUTLASS_HOST_DEVICE
|
| 190 |
+
RegularTileAccessIterator &operator++() {
|
| 191 |
+
++iteration_contiguous_;
|
| 192 |
+
|
| 193 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 194 |
+
return *this;
|
| 195 |
+
|
| 196 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 197 |
+
// ThreadMap::Iteration::kContiguous)
|
| 198 |
+
iteration_contiguous_ = 0;
|
| 199 |
+
++iteration_strided_;
|
| 200 |
+
|
| 201 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 202 |
+
return *this;
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
// Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided)
|
| 206 |
+
// which means we enter the next tile.
|
| 207 |
+
iteration_strided_ = 0;
|
| 208 |
+
|
| 209 |
+
return *this;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
/// Advances to the next tile in memory.
|
| 213 |
+
CUTLASS_HOST_DEVICE
|
| 214 |
+
RegularTileAccessIterator operator++(int) {
|
| 215 |
+
RegularTileAccessIterator prev(*this);
|
| 216 |
+
this->operator++();
|
| 217 |
+
|
| 218 |
+
return prev;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
/// Adds a tile offset
|
| 222 |
+
CUTLASS_DEVICE
|
| 223 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 224 |
+
add_pointer_offset(coord.contiguous() * Shape::kContiguous * Layout::kFactor +
|
| 225 |
+
coord.strided() * Shape::kStrided * stride_ *
|
| 226 |
+
Layout::kElementsPerAccess / Layout::kFactor);
|
| 227 |
+
}
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 231 |
+
|
| 232 |
+
/// Tile Iterator specialized for column-major congruous TensorOp formats.
|
| 233 |
+
///
|
| 234 |
+
///
|
| 235 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 236 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 237 |
+
/// WriteableContiguousTileIteratorConcept
|
| 238 |
+
///
|
| 239 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 240 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 241 |
+
class RegularTileAccessIterator<
|
| 242 |
+
Shape_, Element_,
|
| 243 |
+
layout::ColumnMajorTensorOpMultiplicandCongruous<
|
| 244 |
+
sizeof_bits<Element_>::value, Crosswise>,
|
| 245 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 246 |
+
public:
|
| 247 |
+
static_assert(
|
| 248 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 249 |
+
"Specialization for column-major iterator may along advance along the "
|
| 250 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 251 |
+
|
| 252 |
+
using Shape = Shape_;
|
| 253 |
+
using Element = Element_;
|
| 254 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous<
|
| 255 |
+
sizeof_bits<Element_>::value, Crosswise>;
|
| 256 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 257 |
+
static int const kAlignment = Alignment;
|
| 258 |
+
|
| 259 |
+
using Index = typename Layout::Index;
|
| 260 |
+
using LongIndex = typename Layout::LongIndex;
|
| 261 |
+
|
| 262 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 263 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 264 |
+
|
| 265 |
+
using ThreadMap = ThreadMap_;
|
| 266 |
+
|
| 267 |
+
/// Underlying iterator type
|
| 268 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 269 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 270 |
+
layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
|
| 271 |
+
Crosswise>,
|
| 272 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 273 |
+
|
| 274 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 275 |
+
|
| 276 |
+
private:
|
| 277 |
+
/// Underlying iterator
|
| 278 |
+
UnderlyingIterator iterator_;
|
| 279 |
+
|
| 280 |
+
public:
|
| 281 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 282 |
+
CUTLASS_HOST_DEVICE
|
| 283 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 284 |
+
int thread_id ///< ID of each participating thread
|
| 285 |
+
)
|
| 286 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 287 |
+
|
| 288 |
+
/// Overrides the internal iteration index
|
| 289 |
+
CUTLASS_HOST_DEVICE
|
| 290 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 291 |
+
|
| 292 |
+
/// Adds a pointer offset in units of Element
|
| 293 |
+
CUTLASS_HOST_DEVICE
|
| 294 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 295 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
/// Returns a pointer
|
| 299 |
+
CUTLASS_HOST_DEVICE
|
| 300 |
+
AccessType *get() const {
|
| 301 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
/// Adds a tile offset
|
| 305 |
+
CUTLASS_DEVICE
|
| 306 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 307 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 308 |
+
}
|
| 309 |
+
|
| 310 |
+
/// Advances to the next tile in memory.
|
| 311 |
+
CUTLASS_HOST_DEVICE
|
| 312 |
+
RegularTileAccessIterator &operator++() {
|
| 313 |
+
++iterator_;
|
| 314 |
+
return *this;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
/// Advances to the next tile in memory.
|
| 318 |
+
CUTLASS_HOST_DEVICE
|
| 319 |
+
RegularTileAccessIterator operator++(int) {
|
| 320 |
+
RegularTileAccessIterator prev(*this);
|
| 321 |
+
++iterator_;
|
| 322 |
+
|
| 323 |
+
return prev;
|
| 324 |
+
}
|
| 325 |
+
};
|
| 326 |
+
|
| 327 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 328 |
+
|
| 329 |
+
/// Tile Iterator specialized for row-major congruous TensorOp formats.
|
| 330 |
+
///
|
| 331 |
+
///
|
| 332 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 333 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 334 |
+
/// WriteableContiguousTileIteratorConcept
|
| 335 |
+
///
|
| 336 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 337 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 338 |
+
class RegularTileAccessIterator<
|
| 339 |
+
Shape_, Element_,
|
| 340 |
+
layout::RowMajorTensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
|
| 341 |
+
Crosswise>,
|
| 342 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 343 |
+
public:
|
| 344 |
+
static_assert(
|
| 345 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 346 |
+
"Specialization for row-major iterator may along advance along the "
|
| 347 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 348 |
+
|
| 349 |
+
using Shape = Shape_;
|
| 350 |
+
using Element = Element_;
|
| 351 |
+
using Layout = layout::RowMajorTensorOpMultiplicandCongruous<
|
| 352 |
+
sizeof_bits<Element_>::value, Crosswise>;
|
| 353 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 354 |
+
static int const kAlignment = Alignment;
|
| 355 |
+
|
| 356 |
+
using Index = typename Layout::Index;
|
| 357 |
+
using LongIndex = typename Layout::LongIndex;
|
| 358 |
+
|
| 359 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 360 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 361 |
+
|
| 362 |
+
using ThreadMap = ThreadMap_;
|
| 363 |
+
|
| 364 |
+
/// Underlying iterator type
|
| 365 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 366 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 367 |
+
layout::TensorOpMultiplicandCongruous<sizeof_bits<Element_>::value,
|
| 368 |
+
Crosswise>,
|
| 369 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 370 |
+
|
| 371 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 372 |
+
|
| 373 |
+
private:
|
| 374 |
+
/// Underlying iterator
|
| 375 |
+
UnderlyingIterator iterator_;
|
| 376 |
+
|
| 377 |
+
public:
|
| 378 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 379 |
+
CUTLASS_HOST_DEVICE
|
| 380 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 381 |
+
int thread_id ///< ID of each participating thread
|
| 382 |
+
)
|
| 383 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 384 |
+
|
| 385 |
+
/// Overrides the internal iteration index
|
| 386 |
+
CUTLASS_HOST_DEVICE
|
| 387 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 388 |
+
|
| 389 |
+
/// Adds a pointer offset in units of Element
|
| 390 |
+
CUTLASS_HOST_DEVICE
|
| 391 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 392 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
/// Returns a pointer
|
| 396 |
+
CUTLASS_HOST_DEVICE
|
| 397 |
+
AccessType *get() const {
|
| 398 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
/// Adds a tile offset
|
| 402 |
+
CUTLASS_DEVICE
|
| 403 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 404 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
/// Advances to the next tile in memory.
|
| 408 |
+
CUTLASS_HOST_DEVICE
|
| 409 |
+
RegularTileAccessIterator &operator++() {
|
| 410 |
+
++iterator_;
|
| 411 |
+
return *this;
|
| 412 |
+
}
|
| 413 |
+
|
| 414 |
+
/// Advances to the next tile in memory.
|
| 415 |
+
CUTLASS_HOST_DEVICE
|
| 416 |
+
RegularTileAccessIterator operator++(int) {
|
| 417 |
+
RegularTileAccessIterator prev(*this);
|
| 418 |
+
++iterator_;
|
| 419 |
+
|
| 420 |
+
return prev;
|
| 421 |
+
}
|
| 422 |
+
};
|
| 423 |
+
|
| 424 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 425 |
+
|
| 426 |
+
/// Tile iterator specialized for crosswise arrangements for TensorOps
|
| 427 |
+
///
|
| 428 |
+
///
|
| 429 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 430 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 431 |
+
/// WriteableContiguousTileIteratorConcept
|
| 432 |
+
///
|
| 433 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 434 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 435 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 436 |
+
layout::TensorOpMultiplicandCrosswise<
|
| 437 |
+
sizeof_bits<Element_>::value, Crosswise>,
|
| 438 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 439 |
+
public:
|
| 440 |
+
static_assert(
|
| 441 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 442 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 443 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 444 |
+
|
| 445 |
+
using Shape = Shape_;
|
| 446 |
+
using Element = Element_;
|
| 447 |
+
using Layout =
|
| 448 |
+
layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
|
| 449 |
+
Crosswise>;
|
| 450 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 451 |
+
static int const kAlignment = Alignment;
|
| 452 |
+
static int const kCrosswise = Crosswise;
|
| 453 |
+
|
| 454 |
+
using Index = typename Layout::Index;
|
| 455 |
+
using LongIndex = typename Layout::LongIndex;
|
| 456 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 457 |
+
|
| 458 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 459 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 460 |
+
|
| 461 |
+
using ThreadMap = ThreadMap_;
|
| 462 |
+
|
| 463 |
+
static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise),
|
| 464 |
+
"kCrosswise is the smallest unit in the contiguous dimension "
|
| 465 |
+
"for shared memory swizzling.");
|
| 466 |
+
|
| 467 |
+
/// Internal details made public to facilitate introspection
|
| 468 |
+
struct Detail {
|
| 469 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 470 |
+
/// length.
|
| 471 |
+
static int const kAccessSizeInBits = 128;
|
| 472 |
+
|
| 473 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 474 |
+
ThreadMap::kElementsPerAccess ==
|
| 475 |
+
kAccessSizeInBits,
|
| 476 |
+
"This iterator requires a policy whose access size is 128bs");
|
| 477 |
+
|
| 478 |
+
/// Number of pointers
|
| 479 |
+
///
|
| 480 |
+
/// Note:TN kblock32 layouts only needs 1 pointer, but strangely
|
| 481 |
+
/// reducing pointer count hurts perfomrnace
|
| 482 |
+
static int const kPointerCount =
|
| 483 |
+
(ThreadMap::Iterations::kStrided > 1 ? 2 : 1);
|
| 484 |
+
};
|
| 485 |
+
|
| 486 |
+
/// Element type per access
|
| 487 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 488 |
+
|
| 489 |
+
private:
|
| 490 |
+
//
|
| 491 |
+
// Data members
|
| 492 |
+
//
|
| 493 |
+
|
| 494 |
+
/// Total number of sections. The memory is divided into stages. One stage
|
| 495 |
+
/// can store one tile. Stage is divided into sections. Interleaved layout
|
| 496 |
+
/// can have multiple sections in a stage. The rest layout only has one section
|
| 497 |
+
/// in a stage.
|
| 498 |
+
int sections_;
|
| 499 |
+
|
| 500 |
+
/// Sections that a stage has
|
| 501 |
+
int sections_per_stage_;
|
| 502 |
+
|
| 503 |
+
/// Stride value
|
| 504 |
+
StrideIndex stride_;
|
| 505 |
+
|
| 506 |
+
/// Internal pointer to first access of tile
|
| 507 |
+
AccessType *pointer_[Detail::kPointerCount];
|
| 508 |
+
|
| 509 |
+
/// Internal byte offset
|
| 510 |
+
Index byte_offset_;
|
| 511 |
+
|
| 512 |
+
/// Iteration in the contiguous dimension
|
| 513 |
+
int iteration_contiguous_;
|
| 514 |
+
|
| 515 |
+
/// Iteration in the strided dimension
|
| 516 |
+
int iteration_strided_;
|
| 517 |
+
|
| 518 |
+
public:
|
| 519 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 520 |
+
CUTLASS_HOST_DEVICE
|
| 521 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 522 |
+
int thread_id ///< ID of each participating thread
|
| 523 |
+
)
|
| 524 |
+
: sections_(ref.stride(0) / kCrosswise),
|
| 525 |
+
sections_per_stage_(Shape::kContiguous / kCrosswise),
|
| 526 |
+
// stride_ = kCrosswise x sections_ x kFactor
|
| 527 |
+
stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess),
|
| 528 |
+
byte_offset_(0) {
|
| 529 |
+
layout::PitchLinearCoord thread_offset_base =
|
| 530 |
+
ThreadMap::initial_offset(thread_id);
|
| 531 |
+
|
| 532 |
+
CUTLASS_PRAGMA_UNROLL
|
| 533 |
+
for (int i = 0; i < Detail::kPointerCount; ++i) {
|
| 534 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 535 |
+
// pointer (units of elements)
|
| 536 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile =
|
| 537 |
+
thread_offset_base +
|
| 538 |
+
layout::PitchLinearCoord{
|
| 539 |
+
0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i};
|
| 540 |
+
// initialize pointer
|
| 541 |
+
pointer_[i] = reinterpret_cast<AccessType *>(ref.data()) +
|
| 542 |
+
ref.offset(thread_offset_in_threadblock_tile) /
|
| 543 |
+
Layout::kElementsPerAccess;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
set_iteration_index(0);
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
/// Overrides the internal iteration index
|
| 550 |
+
CUTLASS_HOST_DEVICE
|
| 551 |
+
void set_iteration_index(int index) {
|
| 552 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 553 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
/// Adds a pointer offset in units of Element
|
| 557 |
+
CUTLASS_HOST_DEVICE
|
| 558 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 559 |
+
byte_offset_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
/// Returns a pointer
|
| 563 |
+
CUTLASS_HOST_DEVICE
|
| 564 |
+
AccessType *get() const {
|
| 565 |
+
AccessType *access_ptr = pointer_[iteration_strided_ & 1];
|
| 566 |
+
int stride_idx = (iteration_strided_ & ~1);
|
| 567 |
+
|
| 568 |
+
int access_offset =
|
| 569 |
+
stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor +
|
| 570 |
+
// kCrosswise elements in the contiguous dimension would span to a
|
| 571 |
+
// shared memory cache line.
|
| 572 |
+
iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) *
|
| 573 |
+
Layout::TileShape::kContiguous;
|
| 574 |
+
char *access_byte_ptr =
|
| 575 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 576 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
/// Advances to the next tile in memory.
|
| 580 |
+
CUTLASS_HOST_DEVICE
|
| 581 |
+
RegularTileAccessIterator &operator++() {
|
| 582 |
+
++iteration_contiguous_;
|
| 583 |
+
|
| 584 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 585 |
+
return *this;
|
| 586 |
+
|
| 587 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 588 |
+
// ThreadMap::Iteration::kContiguous)
|
| 589 |
+
iteration_contiguous_ = 0;
|
| 590 |
+
++iteration_strided_;
|
| 591 |
+
|
| 592 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 593 |
+
return *this;
|
| 594 |
+
}
|
| 595 |
+
|
| 596 |
+
// Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided)
|
| 597 |
+
// which means we enter the next section.
|
| 598 |
+
iteration_strided_ = 0;
|
| 599 |
+
|
| 600 |
+
return *this;
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
/// Advances to the next tile in memory.
|
| 604 |
+
CUTLASS_HOST_DEVICE
|
| 605 |
+
RegularTileAccessIterator operator++(int) {
|
| 606 |
+
RegularTileAccessIterator prev(*this);
|
| 607 |
+
this->operator++();
|
| 608 |
+
|
| 609 |
+
return prev;
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
/// Adds a tile offset
|
| 613 |
+
CUTLASS_DEVICE
|
| 614 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 615 |
+
add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ *
|
| 616 |
+
ThreadMap::kElementsPerAccess / sections_ +
|
| 617 |
+
coord.strided() * Shape::kStrided * stride_ *
|
| 618 |
+
Layout::kElementsPerAccess / Layout::kFactor);
|
| 619 |
+
}
|
| 620 |
+
};
|
| 621 |
+
|
| 622 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 623 |
+
|
| 624 |
+
/// Tile Iterator specialized for column-major crosswise TensorOp formats.
|
| 625 |
+
///
|
| 626 |
+
///
|
| 627 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 628 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 629 |
+
/// WriteableContiguousTileIteratorConcept
|
| 630 |
+
///
|
| 631 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 632 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 633 |
+
class RegularTileAccessIterator<
|
| 634 |
+
Shape_, Element_,
|
| 635 |
+
layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
| 636 |
+
sizeof_bits<Element_>::value, Crosswise>,
|
| 637 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 638 |
+
public:
|
| 639 |
+
static_assert(
|
| 640 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 641 |
+
"Specialization for column-major iterator may along advance along the "
|
| 642 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 643 |
+
|
| 644 |
+
using Shape = Shape_;
|
| 645 |
+
using Element = Element_;
|
| 646 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise<
|
| 647 |
+
sizeof_bits<Element_>::value, Crosswise>;
|
| 648 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 649 |
+
static int const kAlignment = Alignment;
|
| 650 |
+
|
| 651 |
+
using Index = typename Layout::Index;
|
| 652 |
+
using LongIndex = typename Layout::LongIndex;
|
| 653 |
+
|
| 654 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 655 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 656 |
+
|
| 657 |
+
using ThreadMap = ThreadMap_;
|
| 658 |
+
|
| 659 |
+
/// Underlying iterator type
|
| 660 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 661 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 662 |
+
layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
|
| 663 |
+
Crosswise>,
|
| 664 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 665 |
+
|
| 666 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 667 |
+
|
| 668 |
+
private:
|
| 669 |
+
/// Underlying iterator
|
| 670 |
+
UnderlyingIterator iterator_;
|
| 671 |
+
|
| 672 |
+
public:
|
| 673 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 674 |
+
CUTLASS_HOST_DEVICE
|
| 675 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 676 |
+
int thread_id ///< ID of each participating thread
|
| 677 |
+
)
|
| 678 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 679 |
+
|
| 680 |
+
/// Overrides the internal iteration index
|
| 681 |
+
CUTLASS_HOST_DEVICE
|
| 682 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 683 |
+
|
| 684 |
+
/// Adds a pointer offset in units of Element
|
| 685 |
+
CUTLASS_HOST_DEVICE
|
| 686 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 687 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
/// Returns a pointer
|
| 691 |
+
CUTLASS_HOST_DEVICE
|
| 692 |
+
AccessType *get() const {
|
| 693 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
/// Adds a tile offset
|
| 697 |
+
CUTLASS_DEVICE
|
| 698 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 699 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
/// Advances to the next tile in memory.
|
| 703 |
+
CUTLASS_HOST_DEVICE
|
| 704 |
+
RegularTileAccessIterator &operator++() {
|
| 705 |
+
++iterator_;
|
| 706 |
+
return *this;
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
/// Advances to the next tile in memory.
|
| 710 |
+
CUTLASS_HOST_DEVICE
|
| 711 |
+
RegularTileAccessIterator operator++(int) {
|
| 712 |
+
RegularTileAccessIterator prev(*this);
|
| 713 |
+
++iterator_;
|
| 714 |
+
|
| 715 |
+
return prev;
|
| 716 |
+
}
|
| 717 |
+
};
|
| 718 |
+
|
| 719 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 720 |
+
|
| 721 |
+
/// Tile Iterator specialized for row-major crosswise TensorOp formats.
|
| 722 |
+
///
|
| 723 |
+
///
|
| 724 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 725 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 726 |
+
/// WriteableContiguousTileIteratorConcept
|
| 727 |
+
///
|
| 728 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 729 |
+
typename ThreadMap_, int Alignment, int Crosswise>
|
| 730 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 731 |
+
layout::RowMajorTensorOpMultiplicandCrosswise<
|
| 732 |
+
sizeof_bits<Element_>::value, Crosswise>,
|
| 733 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 734 |
+
public:
|
| 735 |
+
static_assert(
|
| 736 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 737 |
+
"Specialization for row-major iterator may along advance along the "
|
| 738 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 739 |
+
|
| 740 |
+
using Shape = Shape_;
|
| 741 |
+
using Element = Element_;
|
| 742 |
+
using Layout = layout::RowMajorTensorOpMultiplicandCrosswise<
|
| 743 |
+
sizeof_bits<Element_>::value, Crosswise>;
|
| 744 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 745 |
+
static int const kAlignment = Alignment;
|
| 746 |
+
|
| 747 |
+
using Index = typename Layout::Index;
|
| 748 |
+
using LongIndex = typename Layout::LongIndex;
|
| 749 |
+
|
| 750 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 751 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 752 |
+
|
| 753 |
+
using ThreadMap = ThreadMap_;
|
| 754 |
+
|
| 755 |
+
/// Underlying iterator type
|
| 756 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 757 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 758 |
+
layout::TensorOpMultiplicandCrosswise<sizeof_bits<Element_>::value,
|
| 759 |
+
Crosswise>,
|
| 760 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 761 |
+
|
| 762 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 763 |
+
|
| 764 |
+
private:
|
| 765 |
+
/// Underlying iterator
|
| 766 |
+
UnderlyingIterator iterator_;
|
| 767 |
+
|
| 768 |
+
public:
|
| 769 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 770 |
+
CUTLASS_HOST_DEVICE
|
| 771 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 772 |
+
int thread_id ///< ID of each participating thread
|
| 773 |
+
)
|
| 774 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 775 |
+
|
| 776 |
+
/// Overrides the internal iteration index
|
| 777 |
+
CUTLASS_HOST_DEVICE
|
| 778 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 779 |
+
|
| 780 |
+
/// Adds a pointer offset in units of Element
|
| 781 |
+
CUTLASS_HOST_DEVICE
|
| 782 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 783 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
/// Returns a pointer
|
| 787 |
+
CUTLASS_HOST_DEVICE
|
| 788 |
+
AccessType *get() const {
|
| 789 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
/// Adds a tile offset
|
| 793 |
+
CUTLASS_DEVICE
|
| 794 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 795 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
/// Advances to the next tile in memory.
|
| 799 |
+
CUTLASS_HOST_DEVICE
|
| 800 |
+
RegularTileAccessIterator &operator++() {
|
| 801 |
+
++iterator_;
|
| 802 |
+
return *this;
|
| 803 |
+
}
|
| 804 |
+
|
| 805 |
+
/// Advances to the next tile in memory.
|
| 806 |
+
CUTLASS_HOST_DEVICE
|
| 807 |
+
RegularTileAccessIterator operator++(int) {
|
| 808 |
+
RegularTileAccessIterator prev(*this);
|
| 809 |
+
++iterator_;
|
| 810 |
+
|
| 811 |
+
return prev;
|
| 812 |
+
}
|
| 813 |
+
};
|
| 814 |
+
|
| 815 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 816 |
+
|
| 817 |
+
} // namespace threadblock
|
| 818 |
+
} // namespace transform
|
| 819 |
+
} // namespace cutlass
|
| 820 |
+
|
| 821 |
+
////////////////////////////////////////////////////////////////////////////////
|
build/torch29-cxx11-cu129-aarch64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h
ADDED
|
@@ -0,0 +1,1532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Templates implementing computing the addresses of storing of tiles
|
| 33 |
+
from pitch-linear rank=2 tensors.
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include "cutlass/array.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 41 |
+
#include "cutlass/layout/tensor_op_multiplicand_sm75.h"
|
| 42 |
+
#include "cutlass/layout/tensor_op_multiplicand_sm80.h"
|
| 43 |
+
#include "cutlass/matrix_coord.h"
|
| 44 |
+
#include "cutlass/matrix_shape.h"
|
| 45 |
+
#include "cutlass/tensor_ref.h"
|
| 46 |
+
#include "cutlass/transform/threadblock/regular_tile_access_iterator.h"
|
| 47 |
+
|
| 48 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 49 |
+
|
| 50 |
+
namespace cutlass {
|
| 51 |
+
namespace transform {
|
| 52 |
+
namespace threadblock {
|
| 53 |
+
|
| 54 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 55 |
+
|
| 56 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 57 |
+
///
|
| 58 |
+
///
|
| 59 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 60 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 61 |
+
/// WriteableContiguousTileIteratorConcept
|
| 62 |
+
///
|
| 63 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 64 |
+
typename ThreadMap_, int Alignment>
|
| 65 |
+
class RegularTileAccessIterator<
|
| 66 |
+
Shape_, Element_,
|
| 67 |
+
layout::TensorOpMultiplicandCongruous64b,
|
| 68 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 69 |
+
public:
|
| 70 |
+
static_assert(
|
| 71 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 72 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 73 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 74 |
+
|
| 75 |
+
using Shape = Shape_;
|
| 76 |
+
using Element = Element_;
|
| 77 |
+
using Layout = layout::TensorOpMultiplicandCongruous64b;
|
| 78 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 79 |
+
static int const kAlignment = Alignment;
|
| 80 |
+
|
| 81 |
+
using Index = typename Layout::Index;
|
| 82 |
+
using LongIndex = typename Layout::LongIndex;
|
| 83 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 84 |
+
|
| 85 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 86 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 87 |
+
|
| 88 |
+
using ThreadMap = ThreadMap_;
|
| 89 |
+
|
| 90 |
+
static_assert(ThreadMap::kThreads / 32 > 1,
|
| 91 |
+
"This tile iterator requires at least two warps.");
|
| 92 |
+
|
| 93 |
+
/// Internal details made public to facilitate introspection
|
| 94 |
+
struct Detail {
|
| 95 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 96 |
+
/// length.
|
| 97 |
+
static int const kAccessSizeInBits = 64;
|
| 98 |
+
|
| 99 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 100 |
+
ThreadMap::kElementsPerAccess ==
|
| 101 |
+
kAccessSizeInBits,
|
| 102 |
+
"This iterator requires a policy whose access size is 64b");
|
| 103 |
+
|
| 104 |
+
///< Number of pointers
|
| 105 |
+
static int const kPointerCount = 1;
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
+
/// Element type per access
|
| 109 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 110 |
+
|
| 111 |
+
private:
|
| 112 |
+
//
|
| 113 |
+
// Data members
|
| 114 |
+
//
|
| 115 |
+
|
| 116 |
+
/// Stride value
|
| 117 |
+
StrideIndex stride_;
|
| 118 |
+
|
| 119 |
+
/// Internal pointer to first access of tile
|
| 120 |
+
AccessType *pointer_;
|
| 121 |
+
|
| 122 |
+
/// Internal byte offset
|
| 123 |
+
Index byte_offset_;
|
| 124 |
+
|
| 125 |
+
/// Iteration in the contiguous dimension
|
| 126 |
+
int iteration_contiguous_;
|
| 127 |
+
|
| 128 |
+
/// Iteration in the strided dimension
|
| 129 |
+
int iteration_strided_;
|
| 130 |
+
|
| 131 |
+
public:
|
| 132 |
+
|
| 133 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 134 |
+
CUTLASS_HOST_DEVICE
|
| 135 |
+
RegularTileAccessIterator(
|
| 136 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 137 |
+
int thread_id ///< ID of each participating thread
|
| 138 |
+
):
|
| 139 |
+
stride_(ref.stride(0) / Layout::kElementsPerAccess),
|
| 140 |
+
byte_offset_(0) {
|
| 141 |
+
|
| 142 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 143 |
+
|
| 144 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 145 |
+
// pointer (units of elements)
|
| 146 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
|
| 147 |
+
|
| 148 |
+
// initialize pointer
|
| 149 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
|
| 150 |
+
|
| 151 |
+
set_iteration_index(0);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
/// Overrides the internal iteration index
|
| 155 |
+
CUTLASS_HOST_DEVICE
|
| 156 |
+
void set_iteration_index(int index) {
|
| 157 |
+
|
| 158 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 159 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
/// Adds a pointer offset in units of Element
|
| 163 |
+
CUTLASS_HOST_DEVICE
|
| 164 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 165 |
+
|
| 166 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
/// Returns a pointer
|
| 170 |
+
CUTLASS_HOST_DEVICE
|
| 171 |
+
AccessType *get() const {
|
| 172 |
+
|
| 173 |
+
AccessType *access_ptr = pointer_;
|
| 174 |
+
|
| 175 |
+
int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
|
| 176 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 177 |
+
ThreadMap::kElementsPerAccess;
|
| 178 |
+
|
| 179 |
+
char *access_byte_ptr =
|
| 180 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 181 |
+
|
| 182 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
/// Advances to the next tile in memory.
|
| 186 |
+
CUTLASS_HOST_DEVICE
|
| 187 |
+
RegularTileAccessIterator &operator++() {
|
| 188 |
+
++iteration_contiguous_;
|
| 189 |
+
|
| 190 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 191 |
+
return *this;
|
| 192 |
+
|
| 193 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 194 |
+
// ThreadMap::Iteration::kContiguous)
|
| 195 |
+
iteration_contiguous_ = 0;
|
| 196 |
+
++iteration_strided_;
|
| 197 |
+
|
| 198 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 199 |
+
return *this;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 203 |
+
// which means we enter the next tile.
|
| 204 |
+
iteration_strided_ = 0;
|
| 205 |
+
|
| 206 |
+
return *this;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
/// Advances to the next tile in memory.
|
| 210 |
+
CUTLASS_HOST_DEVICE
|
| 211 |
+
RegularTileAccessIterator operator++(int) {
|
| 212 |
+
|
| 213 |
+
RegularTileAccessIterator prev(*this);
|
| 214 |
+
|
| 215 |
+
this->operator++();
|
| 216 |
+
|
| 217 |
+
return prev;
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
/// Adds a tile offset
|
| 221 |
+
CUTLASS_DEVICE
|
| 222 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 223 |
+
|
| 224 |
+
add_pointer_offset(
|
| 225 |
+
coord.contiguous() * Shape::kContiguous +
|
| 226 |
+
coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess);
|
| 227 |
+
}
|
| 228 |
+
};
|
| 229 |
+
|
| 230 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 231 |
+
|
| 232 |
+
/// Tile Iterator specialized for column-major congruous TensorOp formats.
|
| 233 |
+
///
|
| 234 |
+
///
|
| 235 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 236 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 237 |
+
/// WriteableContiguousTileIteratorConcept
|
| 238 |
+
///
|
| 239 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 240 |
+
typename ThreadMap_, int Alignment>
|
| 241 |
+
class RegularTileAccessIterator<
|
| 242 |
+
Shape_, Element_,
|
| 243 |
+
layout::ColumnMajorTensorOpMultiplicandCongruous64b,
|
| 244 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 245 |
+
public:
|
| 246 |
+
static_assert(
|
| 247 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 248 |
+
"Specialization for column-major iterator may along advance along the "
|
| 249 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 250 |
+
|
| 251 |
+
using Shape = Shape_;
|
| 252 |
+
using Element = Element_;
|
| 253 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b;
|
| 254 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 255 |
+
static int const kAlignment = Alignment;
|
| 256 |
+
|
| 257 |
+
using Index = typename Layout::Index;
|
| 258 |
+
using LongIndex = typename Layout::LongIndex;
|
| 259 |
+
|
| 260 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 261 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 262 |
+
|
| 263 |
+
using ThreadMap = ThreadMap_;
|
| 264 |
+
|
| 265 |
+
/// Underlying iterator type
|
| 266 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 267 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 268 |
+
layout::TensorOpMultiplicandCongruous64b,
|
| 269 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 270 |
+
|
| 271 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 272 |
+
|
| 273 |
+
private:
|
| 274 |
+
/// Underlying iterator
|
| 275 |
+
UnderlyingIterator iterator_;
|
| 276 |
+
|
| 277 |
+
public:
|
| 278 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 279 |
+
CUTLASS_HOST_DEVICE
|
| 280 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 281 |
+
int thread_id ///< ID of each participating thread
|
| 282 |
+
)
|
| 283 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 284 |
+
|
| 285 |
+
/// Overrides the internal iteration index
|
| 286 |
+
CUTLASS_HOST_DEVICE
|
| 287 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 288 |
+
|
| 289 |
+
/// Adds a pointer offset in units of Element
|
| 290 |
+
CUTLASS_HOST_DEVICE
|
| 291 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 292 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
/// Returns a pointer
|
| 296 |
+
CUTLASS_HOST_DEVICE
|
| 297 |
+
AccessType *get() const {
|
| 298 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 299 |
+
}
|
| 300 |
+
|
| 301 |
+
/// Adds a tile offset
|
| 302 |
+
CUTLASS_DEVICE
|
| 303 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 304 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
/// Advances to the next tile in memory.
|
| 308 |
+
CUTLASS_HOST_DEVICE
|
| 309 |
+
RegularTileAccessIterator &operator++() {
|
| 310 |
+
++iterator_;
|
| 311 |
+
return *this;
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
/// Advances to the next tile in memory.
|
| 315 |
+
CUTLASS_HOST_DEVICE
|
| 316 |
+
RegularTileAccessIterator operator++(int) {
|
| 317 |
+
RegularTileAccessIterator prev(*this);
|
| 318 |
+
++iterator_;
|
| 319 |
+
|
| 320 |
+
return prev;
|
| 321 |
+
}
|
| 322 |
+
};
|
| 323 |
+
|
| 324 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 325 |
+
|
| 326 |
+
/// Tile Iterator specialized for row-major congruous TensorOp formats.
|
| 327 |
+
///
|
| 328 |
+
///
|
| 329 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 330 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 331 |
+
/// WriteableContiguousTileIteratorConcept
|
| 332 |
+
///
|
| 333 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 334 |
+
typename ThreadMap_, int Alignment>
|
| 335 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 336 |
+
layout::RowMajorTensorOpMultiplicandCongruous64b,
|
| 337 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 338 |
+
public:
|
| 339 |
+
static_assert(
|
| 340 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 341 |
+
"Specialization for row-major iterator may along advance along the "
|
| 342 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 343 |
+
|
| 344 |
+
using Shape = Shape_;
|
| 345 |
+
using Element = Element_;
|
| 346 |
+
using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b;
|
| 347 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 348 |
+
static int const kAlignment = Alignment;
|
| 349 |
+
|
| 350 |
+
using Index = typename Layout::Index;
|
| 351 |
+
using LongIndex = typename Layout::LongIndex;
|
| 352 |
+
|
| 353 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 354 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 355 |
+
|
| 356 |
+
using ThreadMap = ThreadMap_;
|
| 357 |
+
|
| 358 |
+
/// Underlying iterator type
|
| 359 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 360 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 361 |
+
layout::TensorOpMultiplicandCongruous64b,
|
| 362 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 363 |
+
|
| 364 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 365 |
+
|
| 366 |
+
private:
|
| 367 |
+
/// Underlying iterator
|
| 368 |
+
UnderlyingIterator iterator_;
|
| 369 |
+
|
| 370 |
+
public:
|
| 371 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 372 |
+
CUTLASS_HOST_DEVICE
|
| 373 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 374 |
+
int thread_id ///< ID of each participating thread
|
| 375 |
+
)
|
| 376 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 377 |
+
|
| 378 |
+
/// Overrides the internal iteration index
|
| 379 |
+
CUTLASS_HOST_DEVICE
|
| 380 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 381 |
+
|
| 382 |
+
/// Adds a pointer offset in units of Element
|
| 383 |
+
CUTLASS_HOST_DEVICE
|
| 384 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 385 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
/// Returns a pointer
|
| 389 |
+
CUTLASS_HOST_DEVICE
|
| 390 |
+
AccessType *get() const {
|
| 391 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
/// Adds a tile offset
|
| 395 |
+
CUTLASS_DEVICE
|
| 396 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 397 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
/// Advances to the next tile in memory.
|
| 401 |
+
CUTLASS_HOST_DEVICE
|
| 402 |
+
RegularTileAccessIterator &operator++() {
|
| 403 |
+
++iterator_;
|
| 404 |
+
return *this;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
/// Advances to the next tile in memory.
|
| 408 |
+
CUTLASS_HOST_DEVICE
|
| 409 |
+
RegularTileAccessIterator operator++(int) {
|
| 410 |
+
RegularTileAccessIterator prev(*this);
|
| 411 |
+
++iterator_;
|
| 412 |
+
|
| 413 |
+
return prev;
|
| 414 |
+
}
|
| 415 |
+
};
|
| 416 |
+
|
| 417 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 418 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 419 |
+
|
| 420 |
+
/// Tile iterator specialized for crosswise arrangements for TensorOps
|
| 421 |
+
///
|
| 422 |
+
///
|
| 423 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 424 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 425 |
+
/// WriteableContiguousTileIteratorConcept
|
| 426 |
+
///
|
| 427 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 428 |
+
typename ThreadMap_, int Alignment>
|
| 429 |
+
class RegularTileAccessIterator<
|
| 430 |
+
Shape_, Element_,
|
| 431 |
+
layout::TensorOpMultiplicand64bCrosswise,
|
| 432 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 433 |
+
public:
|
| 434 |
+
static_assert(
|
| 435 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 436 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 437 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 438 |
+
|
| 439 |
+
using Shape = Shape_;
|
| 440 |
+
using Element = Element_;
|
| 441 |
+
using Layout = layout::TensorOpMultiplicand64bCrosswise;
|
| 442 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 443 |
+
static int const kAlignment = Alignment;
|
| 444 |
+
|
| 445 |
+
using Index = typename Layout::Index;
|
| 446 |
+
using LongIndex = typename Layout::LongIndex;
|
| 447 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 448 |
+
|
| 449 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 450 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 451 |
+
|
| 452 |
+
using ThreadMap = ThreadMap_;
|
| 453 |
+
|
| 454 |
+
static_assert(ThreadMap::kThreads / 32 > 1,
|
| 455 |
+
"This tile iterator requires at least two warps.");
|
| 456 |
+
|
| 457 |
+
/// Internal details made public to facilitate introspection
|
| 458 |
+
struct Detail {
|
| 459 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 460 |
+
/// length.
|
| 461 |
+
static int const kAccessSizeInBits = 64;
|
| 462 |
+
|
| 463 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 464 |
+
ThreadMap::kElementsPerAccess ==
|
| 465 |
+
kAccessSizeInBits,
|
| 466 |
+
"This iterator requires a policy whose access size is 64b");
|
| 467 |
+
|
| 468 |
+
///< Number of pointers - two pointers are needed if making more than 4 iterations along
|
| 469 |
+
///< strided dimension
|
| 470 |
+
static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1);
|
| 471 |
+
};
|
| 472 |
+
|
| 473 |
+
/// Element type per access
|
| 474 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 475 |
+
|
| 476 |
+
private:
|
| 477 |
+
//
|
| 478 |
+
// Data members
|
| 479 |
+
//
|
| 480 |
+
|
| 481 |
+
/// Stride value
|
| 482 |
+
StrideIndex stride_;
|
| 483 |
+
|
| 484 |
+
/// Internal pointer to first access of tile
|
| 485 |
+
AccessType *pointer_;
|
| 486 |
+
|
| 487 |
+
/// Internal byte offset
|
| 488 |
+
Index byte_offset_[Detail::kPointerCount];
|
| 489 |
+
|
| 490 |
+
/// Iteration in the contiguous dimension
|
| 491 |
+
int iteration_contiguous_;
|
| 492 |
+
|
| 493 |
+
/// Iteration in the strided dimension
|
| 494 |
+
int iteration_strided_;
|
| 495 |
+
|
| 496 |
+
public:
|
| 497 |
+
|
| 498 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 499 |
+
CUTLASS_DEVICE
|
| 500 |
+
RegularTileAccessIterator(
|
| 501 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 502 |
+
int thread_id ///< ID of each participating thread
|
| 503 |
+
):
|
| 504 |
+
stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) {
|
| 505 |
+
|
| 506 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 507 |
+
|
| 508 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 509 |
+
// pointer (units of elements)
|
| 510 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
|
| 511 |
+
|
| 512 |
+
// initialize pointer
|
| 513 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data());
|
| 514 |
+
|
| 515 |
+
byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element);
|
| 516 |
+
|
| 517 |
+
if (Detail::kPointerCount == 2) {
|
| 518 |
+
byte_offset_[1] = byte_offset_[0] ^ 8;
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
set_iteration_index(0);
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
/// Overrides the internal iteration index
|
| 525 |
+
CUTLASS_HOST_DEVICE
|
| 526 |
+
void set_iteration_index(int index) {
|
| 527 |
+
|
| 528 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 529 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
/// Adds a pointer offset in units of Element
|
| 533 |
+
CUTLASS_HOST_DEVICE
|
| 534 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 535 |
+
|
| 536 |
+
pointer_ += pointer_offset / ThreadMap::kElementsPerAccess;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
/// Returns a pointer
|
| 540 |
+
CUTLASS_DEVICE
|
| 541 |
+
AccessType *get() const {
|
| 542 |
+
|
| 543 |
+
// Map the logical contiguous and strided access to the internal swizzled structure.
|
| 544 |
+
int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_;
|
| 545 |
+
|
| 546 |
+
char *access_byte_ptr = reinterpret_cast<char *>(pointer_ + uniform_offset);
|
| 547 |
+
|
| 548 |
+
int byte_offset;
|
| 549 |
+
|
| 550 |
+
// This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations)
|
| 551 |
+
// in the strided dimension
|
| 552 |
+
if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) {
|
| 553 |
+
byte_offset = byte_offset_[1];
|
| 554 |
+
}
|
| 555 |
+
else {
|
| 556 |
+
byte_offset = byte_offset_[0];
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset);
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
/// Advances to the next tile in memory.
|
| 563 |
+
CUTLASS_HOST_DEVICE
|
| 564 |
+
RegularTileAccessIterator &operator++() {
|
| 565 |
+
++iteration_contiguous_;
|
| 566 |
+
|
| 567 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 568 |
+
return *this;
|
| 569 |
+
|
| 570 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 571 |
+
// ThreadMap::Iteration::kContiguous)
|
| 572 |
+
iteration_contiguous_ = 0;
|
| 573 |
+
++iteration_strided_;
|
| 574 |
+
|
| 575 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 576 |
+
return *this;
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 580 |
+
// which means we enter the next tile.
|
| 581 |
+
iteration_strided_ = 0;
|
| 582 |
+
|
| 583 |
+
return *this;
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
/// Advances to the next tile in memory.
|
| 587 |
+
CUTLASS_HOST_DEVICE
|
| 588 |
+
RegularTileAccessIterator operator++(int) {
|
| 589 |
+
|
| 590 |
+
RegularTileAccessIterator prev(*this);
|
| 591 |
+
|
| 592 |
+
this->operator++();
|
| 593 |
+
|
| 594 |
+
return prev;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
/// Adds a tile offset
|
| 598 |
+
CUTLASS_DEVICE
|
| 599 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 600 |
+
|
| 601 |
+
add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_);
|
| 602 |
+
}
|
| 603 |
+
};
|
| 604 |
+
|
| 605 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 606 |
+
|
| 607 |
+
/// Tile Iterator specialized for column-major crosswise TensorOp formats.
|
| 608 |
+
///
|
| 609 |
+
///
|
| 610 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 611 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 612 |
+
/// WriteableContiguousTileIteratorConcept
|
| 613 |
+
///
|
| 614 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 615 |
+
typename ThreadMap_, int Alignment>
|
| 616 |
+
class RegularTileAccessIterator<
|
| 617 |
+
Shape_, Element_,
|
| 618 |
+
layout::ColumnMajorTensorOpMultiplicand64bCrosswise,
|
| 619 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 620 |
+
public:
|
| 621 |
+
static_assert(
|
| 622 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 623 |
+
"Specialization for column-major iterator may along advance along the "
|
| 624 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 625 |
+
|
| 626 |
+
using Shape = Shape_;
|
| 627 |
+
using Element = Element_;
|
| 628 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise;
|
| 629 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 630 |
+
static int const kAlignment = Alignment;
|
| 631 |
+
|
| 632 |
+
using Index = typename Layout::Index;
|
| 633 |
+
using LongIndex = typename Layout::LongIndex;
|
| 634 |
+
|
| 635 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 636 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 637 |
+
|
| 638 |
+
using ThreadMap = ThreadMap_;
|
| 639 |
+
|
| 640 |
+
/// Underlying iterator type
|
| 641 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 642 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 643 |
+
layout::TensorOpMultiplicand64bCrosswise,
|
| 644 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 645 |
+
|
| 646 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 647 |
+
|
| 648 |
+
private:
|
| 649 |
+
/// Underlying iterator
|
| 650 |
+
UnderlyingIterator iterator_;
|
| 651 |
+
|
| 652 |
+
public:
|
| 653 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 654 |
+
CUTLASS_HOST_DEVICE
|
| 655 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 656 |
+
int thread_id ///< ID of each participating thread
|
| 657 |
+
)
|
| 658 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 659 |
+
|
| 660 |
+
/// Overrides the internal iteration index
|
| 661 |
+
CUTLASS_HOST_DEVICE
|
| 662 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 663 |
+
|
| 664 |
+
/// Adds a pointer offset in units of Element
|
| 665 |
+
CUTLASS_HOST_DEVICE
|
| 666 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 667 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
/// Returns a pointer
|
| 671 |
+
CUTLASS_HOST_DEVICE
|
| 672 |
+
AccessType *get() const {
|
| 673 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
/// Adds a tile offset
|
| 677 |
+
CUTLASS_DEVICE
|
| 678 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 679 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
/// Advances to the next tile in memory.
|
| 683 |
+
CUTLASS_HOST_DEVICE
|
| 684 |
+
RegularTileAccessIterator &operator++() {
|
| 685 |
+
++iterator_;
|
| 686 |
+
return *this;
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
/// Advances to the next tile in memory.
|
| 690 |
+
CUTLASS_HOST_DEVICE
|
| 691 |
+
RegularTileAccessIterator operator++(int) {
|
| 692 |
+
RegularTileAccessIterator prev(*this);
|
| 693 |
+
++iterator_;
|
| 694 |
+
|
| 695 |
+
return prev;
|
| 696 |
+
}
|
| 697 |
+
};
|
| 698 |
+
|
| 699 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 700 |
+
|
| 701 |
+
/// Tile Iterator specialized for row-major crosswise TensorOp formats.
|
| 702 |
+
///
|
| 703 |
+
///
|
| 704 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 705 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 706 |
+
/// WriteableContiguousTileIteratorConcept
|
| 707 |
+
///
|
| 708 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 709 |
+
typename ThreadMap_, int Alignment>
|
| 710 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 711 |
+
layout::RowMajorTensorOpMultiplicand64bCrosswise,
|
| 712 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 713 |
+
public:
|
| 714 |
+
static_assert(
|
| 715 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 716 |
+
"Specialization for row-major iterator may along advance along the "
|
| 717 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 718 |
+
|
| 719 |
+
using Shape = Shape_;
|
| 720 |
+
using Element = Element_;
|
| 721 |
+
using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise;
|
| 722 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 723 |
+
static int const kAlignment = Alignment;
|
| 724 |
+
|
| 725 |
+
using Index = typename Layout::Index;
|
| 726 |
+
using LongIndex = typename Layout::LongIndex;
|
| 727 |
+
|
| 728 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 729 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 730 |
+
|
| 731 |
+
using ThreadMap = ThreadMap_;
|
| 732 |
+
|
| 733 |
+
/// Underlying iterator type
|
| 734 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 735 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 736 |
+
layout::TensorOpMultiplicand64bCrosswise,
|
| 737 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 738 |
+
|
| 739 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 740 |
+
|
| 741 |
+
private:
|
| 742 |
+
/// Underlying iterator
|
| 743 |
+
UnderlyingIterator iterator_;
|
| 744 |
+
|
| 745 |
+
public:
|
| 746 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 747 |
+
CUTLASS_HOST_DEVICE
|
| 748 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 749 |
+
int thread_id ///< ID of each participating thread
|
| 750 |
+
)
|
| 751 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 752 |
+
|
| 753 |
+
/// Overrides the internal iteration index
|
| 754 |
+
CUTLASS_HOST_DEVICE
|
| 755 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 756 |
+
|
| 757 |
+
/// Adds a pointer offset in units of Element
|
| 758 |
+
CUTLASS_HOST_DEVICE
|
| 759 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 760 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 761 |
+
}
|
| 762 |
+
|
| 763 |
+
/// Returns a pointer
|
| 764 |
+
CUTLASS_HOST_DEVICE
|
| 765 |
+
AccessType *get() const {
|
| 766 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 767 |
+
}
|
| 768 |
+
|
| 769 |
+
/// Adds a tile offset
|
| 770 |
+
CUTLASS_DEVICE
|
| 771 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 772 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 773 |
+
}
|
| 774 |
+
|
| 775 |
+
/// Advances to the next tile in memory.
|
| 776 |
+
CUTLASS_HOST_DEVICE
|
| 777 |
+
RegularTileAccessIterator &operator++() {
|
| 778 |
+
++iterator_;
|
| 779 |
+
return *this;
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
/// Advances to the next tile in memory.
|
| 783 |
+
CUTLASS_HOST_DEVICE
|
| 784 |
+
RegularTileAccessIterator operator++(int) {
|
| 785 |
+
RegularTileAccessIterator prev(*this);
|
| 786 |
+
++iterator_;
|
| 787 |
+
|
| 788 |
+
return prev;
|
| 789 |
+
}
|
| 790 |
+
};
|
| 791 |
+
|
| 792 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 793 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 794 |
+
|
| 795 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 796 |
+
///
|
| 797 |
+
///
|
| 798 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 799 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 800 |
+
/// WriteableContiguousTileIteratorConcept
|
| 801 |
+
///
|
| 802 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 803 |
+
typename ThreadMap_, int Alignment>
|
| 804 |
+
class RegularTileAccessIterator<
|
| 805 |
+
Shape_, Element_,
|
| 806 |
+
layout::TensorOpMultiplicandCongruous128b,
|
| 807 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 808 |
+
public:
|
| 809 |
+
static_assert(
|
| 810 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 811 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 812 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 813 |
+
|
| 814 |
+
using Shape = Shape_;
|
| 815 |
+
using Element = Element_;
|
| 816 |
+
using Layout = layout::TensorOpMultiplicandCongruous128b;
|
| 817 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 818 |
+
static int const kAlignment = Alignment;
|
| 819 |
+
|
| 820 |
+
using Index = typename Layout::Index;
|
| 821 |
+
using LongIndex = typename Layout::LongIndex;
|
| 822 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 823 |
+
|
| 824 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 825 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 826 |
+
|
| 827 |
+
using ThreadMap = ThreadMap_;
|
| 828 |
+
|
| 829 |
+
static_assert(ThreadMap::kThreads / 32 > 1,
|
| 830 |
+
"This tile iterator requires at least two warps.");
|
| 831 |
+
|
| 832 |
+
/// Internal details made public to facilitate introspection
|
| 833 |
+
struct Detail {
|
| 834 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 835 |
+
/// length.
|
| 836 |
+
static int const kAccessSizeInBits = 128;
|
| 837 |
+
|
| 838 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 839 |
+
ThreadMap::kElementsPerAccess ==
|
| 840 |
+
kAccessSizeInBits,
|
| 841 |
+
"This iterator requires a policy whose access size is 128b");
|
| 842 |
+
|
| 843 |
+
///< Number of pointers
|
| 844 |
+
static int const kPointerCount = 1;
|
| 845 |
+
};
|
| 846 |
+
|
| 847 |
+
/// Element type per access
|
| 848 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 849 |
+
|
| 850 |
+
private:
|
| 851 |
+
//
|
| 852 |
+
// Data members
|
| 853 |
+
//
|
| 854 |
+
|
| 855 |
+
/// Stride value
|
| 856 |
+
StrideIndex stride_;
|
| 857 |
+
|
| 858 |
+
/// Internal pointer to first access of tile
|
| 859 |
+
AccessType *pointer_;
|
| 860 |
+
|
| 861 |
+
/// Internal byte offset
|
| 862 |
+
Index byte_offset_;
|
| 863 |
+
|
| 864 |
+
/// Iteration in the contiguous dimension
|
| 865 |
+
int iteration_contiguous_;
|
| 866 |
+
|
| 867 |
+
/// Iteration in the strided dimension
|
| 868 |
+
int iteration_strided_;
|
| 869 |
+
|
| 870 |
+
public:
|
| 871 |
+
|
| 872 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 873 |
+
CUTLASS_HOST_DEVICE
|
| 874 |
+
RegularTileAccessIterator(
|
| 875 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 876 |
+
int thread_id ///< ID of each participating thread
|
| 877 |
+
):
|
| 878 |
+
stride_(ref.stride(0) / Layout::kElementsPerAccess),
|
| 879 |
+
byte_offset_(0) {
|
| 880 |
+
|
| 881 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 882 |
+
|
| 883 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 884 |
+
// pointer (units of elements)
|
| 885 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
|
| 886 |
+
|
| 887 |
+
// initialize pointer
|
| 888 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
|
| 889 |
+
|
| 890 |
+
set_iteration_index(0);
|
| 891 |
+
}
|
| 892 |
+
|
| 893 |
+
/// Overrides the internal iteration index
|
| 894 |
+
CUTLASS_HOST_DEVICE
|
| 895 |
+
void set_iteration_index(int index) {
|
| 896 |
+
|
| 897 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 898 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 899 |
+
}
|
| 900 |
+
|
| 901 |
+
/// Adds a pointer offset in units of Element
|
| 902 |
+
CUTLASS_HOST_DEVICE
|
| 903 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 904 |
+
|
| 905 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 906 |
+
}
|
| 907 |
+
|
| 908 |
+
/// Returns a pointer
|
| 909 |
+
CUTLASS_HOST_DEVICE
|
| 910 |
+
AccessType *get() const {
|
| 911 |
+
|
| 912 |
+
AccessType *access_ptr = pointer_;
|
| 913 |
+
|
| 914 |
+
int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ +
|
| 915 |
+
iteration_contiguous_ * ThreadMap::Delta::kContiguous /
|
| 916 |
+
ThreadMap::kElementsPerAccess;
|
| 917 |
+
|
| 918 |
+
char *access_byte_ptr =
|
| 919 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 920 |
+
|
| 921 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 922 |
+
}
|
| 923 |
+
|
| 924 |
+
/// Advances to the next tile in memory.
|
| 925 |
+
CUTLASS_HOST_DEVICE
|
| 926 |
+
RegularTileAccessIterator &operator++() {
|
| 927 |
+
++iteration_contiguous_;
|
| 928 |
+
|
| 929 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 930 |
+
return *this;
|
| 931 |
+
|
| 932 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 933 |
+
// ThreadMap::Iteration::kContiguous)
|
| 934 |
+
iteration_contiguous_ = 0;
|
| 935 |
+
++iteration_strided_;
|
| 936 |
+
|
| 937 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 938 |
+
return *this;
|
| 939 |
+
}
|
| 940 |
+
|
| 941 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 942 |
+
// which means we enter the next tile.
|
| 943 |
+
iteration_strided_ = 0;
|
| 944 |
+
|
| 945 |
+
return *this;
|
| 946 |
+
}
|
| 947 |
+
|
| 948 |
+
/// Advances to the next tile in memory.
|
| 949 |
+
CUTLASS_HOST_DEVICE
|
| 950 |
+
RegularTileAccessIterator operator++(int) {
|
| 951 |
+
|
| 952 |
+
RegularTileAccessIterator prev(*this);
|
| 953 |
+
|
| 954 |
+
this->operator++();
|
| 955 |
+
|
| 956 |
+
return prev;
|
| 957 |
+
}
|
| 958 |
+
|
| 959 |
+
/// Adds a tile offset
|
| 960 |
+
CUTLASS_DEVICE
|
| 961 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 962 |
+
|
| 963 |
+
add_pointer_offset(
|
| 964 |
+
coord.contiguous() * Shape::kContiguous +
|
| 965 |
+
coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess);
|
| 966 |
+
}
|
| 967 |
+
};
|
| 968 |
+
|
| 969 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 970 |
+
|
| 971 |
+
/// Tile Iterator specialized for column-major congruous TensorOp formats.
|
| 972 |
+
///
|
| 973 |
+
///
|
| 974 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 975 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 976 |
+
/// WriteableContiguousTileIteratorConcept
|
| 977 |
+
///
|
| 978 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 979 |
+
typename ThreadMap_, int Alignment>
|
| 980 |
+
class RegularTileAccessIterator<
|
| 981 |
+
Shape_, Element_,
|
| 982 |
+
layout::ColumnMajorTensorOpMultiplicandCongruous128b,
|
| 983 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 984 |
+
public:
|
| 985 |
+
static_assert(
|
| 986 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 987 |
+
"Specialization for column-major iterator may along advance along the "
|
| 988 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 989 |
+
|
| 990 |
+
using Shape = Shape_;
|
| 991 |
+
using Element = Element_;
|
| 992 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b;
|
| 993 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 994 |
+
static int const kAlignment = Alignment;
|
| 995 |
+
|
| 996 |
+
using Index = typename Layout::Index;
|
| 997 |
+
using LongIndex = typename Layout::LongIndex;
|
| 998 |
+
|
| 999 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1000 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1001 |
+
|
| 1002 |
+
using ThreadMap = ThreadMap_;
|
| 1003 |
+
|
| 1004 |
+
/// Underlying iterator type
|
| 1005 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 1006 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 1007 |
+
layout::TensorOpMultiplicandCongruous128b,
|
| 1008 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 1009 |
+
|
| 1010 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1011 |
+
|
| 1012 |
+
private:
|
| 1013 |
+
/// Underlying iterator
|
| 1014 |
+
UnderlyingIterator iterator_;
|
| 1015 |
+
|
| 1016 |
+
public:
|
| 1017 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 1018 |
+
CUTLASS_HOST_DEVICE
|
| 1019 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 1020 |
+
int thread_id ///< ID of each participating thread
|
| 1021 |
+
)
|
| 1022 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 1023 |
+
|
| 1024 |
+
/// Overrides the internal iteration index
|
| 1025 |
+
CUTLASS_HOST_DEVICE
|
| 1026 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1027 |
+
|
| 1028 |
+
/// Adds a pointer offset in units of Element
|
| 1029 |
+
CUTLASS_HOST_DEVICE
|
| 1030 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1031 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1032 |
+
}
|
| 1033 |
+
|
| 1034 |
+
/// Returns a pointer
|
| 1035 |
+
CUTLASS_HOST_DEVICE
|
| 1036 |
+
AccessType *get() const {
|
| 1037 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
/// Adds a tile offset
|
| 1041 |
+
CUTLASS_DEVICE
|
| 1042 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 1043 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
+
/// Advances to the next tile in memory.
|
| 1047 |
+
CUTLASS_HOST_DEVICE
|
| 1048 |
+
RegularTileAccessIterator &operator++() {
|
| 1049 |
+
++iterator_;
|
| 1050 |
+
return *this;
|
| 1051 |
+
}
|
| 1052 |
+
|
| 1053 |
+
/// Advances to the next tile in memory.
|
| 1054 |
+
CUTLASS_HOST_DEVICE
|
| 1055 |
+
RegularTileAccessIterator operator++(int) {
|
| 1056 |
+
RegularTileAccessIterator prev(*this);
|
| 1057 |
+
++iterator_;
|
| 1058 |
+
|
| 1059 |
+
return prev;
|
| 1060 |
+
}
|
| 1061 |
+
};
|
| 1062 |
+
|
| 1063 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1064 |
+
|
| 1065 |
+
/// Tile Iterator specialized for row-major congruous TensorOp formats.
|
| 1066 |
+
///
|
| 1067 |
+
///
|
| 1068 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1069 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1070 |
+
/// WriteableContiguousTileIteratorConcept
|
| 1071 |
+
///
|
| 1072 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1073 |
+
typename ThreadMap_, int Alignment>
|
| 1074 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 1075 |
+
layout::RowMajorTensorOpMultiplicandCongruous128b,
|
| 1076 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 1077 |
+
public:
|
| 1078 |
+
static_assert(
|
| 1079 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1080 |
+
"Specialization for row-major iterator may along advance along the "
|
| 1081 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 1082 |
+
|
| 1083 |
+
using Shape = Shape_;
|
| 1084 |
+
using Element = Element_;
|
| 1085 |
+
using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b;
|
| 1086 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1087 |
+
static int const kAlignment = Alignment;
|
| 1088 |
+
|
| 1089 |
+
using Index = typename Layout::Index;
|
| 1090 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1091 |
+
|
| 1092 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1093 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1094 |
+
|
| 1095 |
+
using ThreadMap = ThreadMap_;
|
| 1096 |
+
|
| 1097 |
+
/// Underlying iterator type
|
| 1098 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 1099 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 1100 |
+
layout::TensorOpMultiplicandCongruous128b,
|
| 1101 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 1102 |
+
|
| 1103 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1104 |
+
|
| 1105 |
+
private:
|
| 1106 |
+
/// Underlying iterator
|
| 1107 |
+
UnderlyingIterator iterator_;
|
| 1108 |
+
|
| 1109 |
+
public:
|
| 1110 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 1111 |
+
CUTLASS_HOST_DEVICE
|
| 1112 |
+
RegularTileAccessIterator(
|
| 1113 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 1114 |
+
int thread_id ///< ID of each participating thread
|
| 1115 |
+
):
|
| 1116 |
+
iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 1117 |
+
|
| 1118 |
+
/// Overrides the internal iteration index
|
| 1119 |
+
CUTLASS_HOST_DEVICE
|
| 1120 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1121 |
+
|
| 1122 |
+
/// Adds a pointer offset in units of Element
|
| 1123 |
+
CUTLASS_HOST_DEVICE
|
| 1124 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1125 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1126 |
+
}
|
| 1127 |
+
|
| 1128 |
+
/// Returns a pointer
|
| 1129 |
+
CUTLASS_HOST_DEVICE
|
| 1130 |
+
AccessType *get() const {
|
| 1131 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
/// Adds a tile offset
|
| 1135 |
+
CUTLASS_DEVICE
|
| 1136 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 1137 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 1138 |
+
}
|
| 1139 |
+
|
| 1140 |
+
/// Advances to the next tile in memory.
|
| 1141 |
+
CUTLASS_HOST_DEVICE
|
| 1142 |
+
RegularTileAccessIterator &operator++() {
|
| 1143 |
+
++iterator_;
|
| 1144 |
+
return *this;
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
/// Advances to the next tile in memory.
|
| 1148 |
+
CUTLASS_HOST_DEVICE
|
| 1149 |
+
RegularTileAccessIterator operator++(int) {
|
| 1150 |
+
RegularTileAccessIterator prev(*this);
|
| 1151 |
+
++iterator_;
|
| 1152 |
+
|
| 1153 |
+
return prev;
|
| 1154 |
+
}
|
| 1155 |
+
};
|
| 1156 |
+
|
| 1157 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1158 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1159 |
+
|
| 1160 |
+
/// Tile iterator specialized for congruous arrangements for TensorOps
|
| 1161 |
+
///
|
| 1162 |
+
///
|
| 1163 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1164 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1165 |
+
/// WriteableContiguousTileIteratorConcept
|
| 1166 |
+
///
|
| 1167 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1168 |
+
typename ThreadMap_, int Alignment>
|
| 1169 |
+
class RegularTileAccessIterator<
|
| 1170 |
+
Shape_, Element_,
|
| 1171 |
+
layout::TensorOpMultiplicandCrosswise128x4,
|
| 1172 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 1173 |
+
public:
|
| 1174 |
+
static_assert(
|
| 1175 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1176 |
+
"Specialization for pitch-linear iterator may along advance along the "
|
| 1177 |
+
"contiguous(rank=0) or strided(rank=1) dimension.");
|
| 1178 |
+
|
| 1179 |
+
using Shape = Shape_;
|
| 1180 |
+
using Element = Element_;
|
| 1181 |
+
using Layout = layout::TensorOpMultiplicandCrosswise128x4;
|
| 1182 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1183 |
+
static int const kAlignment = Alignment;
|
| 1184 |
+
|
| 1185 |
+
using Index = typename Layout::Index;
|
| 1186 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1187 |
+
using StrideIndex = typename Layout::Stride::Index;
|
| 1188 |
+
|
| 1189 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1190 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1191 |
+
|
| 1192 |
+
using ThreadMap = ThreadMap_;
|
| 1193 |
+
|
| 1194 |
+
static_assert(ThreadMap::kThreads / 32 > 1,
|
| 1195 |
+
"This tile iterator requires at least two warps.");
|
| 1196 |
+
|
| 1197 |
+
/// Internal details made public to facilitate introspection
|
| 1198 |
+
struct Detail {
|
| 1199 |
+
/// This iterator is specialized for an access size that is 128 bits in
|
| 1200 |
+
/// length.
|
| 1201 |
+
static int const kAccessSizeInBits = 128;
|
| 1202 |
+
|
| 1203 |
+
static_assert(sizeof_bits<Element_>::value *
|
| 1204 |
+
ThreadMap::kElementsPerAccess ==
|
| 1205 |
+
kAccessSizeInBits,
|
| 1206 |
+
"This iterator requires a policy whose access size is 128b");
|
| 1207 |
+
|
| 1208 |
+
///< Number of pointers
|
| 1209 |
+
static int const kPointerCount = 1;
|
| 1210 |
+
};
|
| 1211 |
+
|
| 1212 |
+
|
| 1213 |
+
static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension");
|
| 1214 |
+
|
| 1215 |
+
/// Element type per access
|
| 1216 |
+
using AccessType = Array<Element, Layout::kElementsPerAccess>;
|
| 1217 |
+
|
| 1218 |
+
private:
|
| 1219 |
+
//
|
| 1220 |
+
// Data members
|
| 1221 |
+
//
|
| 1222 |
+
|
| 1223 |
+
/// Stride value
|
| 1224 |
+
StrideIndex stride_;
|
| 1225 |
+
|
| 1226 |
+
/// Internal pointer to first access of tile
|
| 1227 |
+
AccessType *pointer_;
|
| 1228 |
+
|
| 1229 |
+
/// Internal byte offset
|
| 1230 |
+
Index byte_offset_;
|
| 1231 |
+
|
| 1232 |
+
/// Iteration in the contiguous dimension
|
| 1233 |
+
int iteration_contiguous_;
|
| 1234 |
+
|
| 1235 |
+
/// Iteration in the strided dimension
|
| 1236 |
+
int iteration_strided_;
|
| 1237 |
+
|
| 1238 |
+
public:
|
| 1239 |
+
|
| 1240 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 1241 |
+
CUTLASS_DEVICE
|
| 1242 |
+
RegularTileAccessIterator(
|
| 1243 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 1244 |
+
int thread_id ///< ID of each participating thread
|
| 1245 |
+
):
|
| 1246 |
+
stride_(ref.stride(0) / Layout::kElementsPerAccess),
|
| 1247 |
+
byte_offset_(0) {
|
| 1248 |
+
|
| 1249 |
+
layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id);
|
| 1250 |
+
|
| 1251 |
+
// This is the offset of a thread within a threadblock tile for a specific
|
| 1252 |
+
// pointer (units of elements)
|
| 1253 |
+
layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base;
|
| 1254 |
+
|
| 1255 |
+
// initialize pointer
|
| 1256 |
+
pointer_ = reinterpret_cast<AccessType *>(ref.data() + ref.offset(thread_offset_in_threadblock_tile));
|
| 1257 |
+
|
| 1258 |
+
set_iteration_index(0);
|
| 1259 |
+
}
|
| 1260 |
+
|
| 1261 |
+
/// Overrides the internal iteration index
|
| 1262 |
+
CUTLASS_HOST_DEVICE
|
| 1263 |
+
void set_iteration_index(int index) {
|
| 1264 |
+
|
| 1265 |
+
iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous;
|
| 1266 |
+
iteration_strided_ = index / ThreadMap::Iterations::kContiguous;
|
| 1267 |
+
}
|
| 1268 |
+
|
| 1269 |
+
/// Adds a pointer offset in units of Element
|
| 1270 |
+
CUTLASS_HOST_DEVICE
|
| 1271 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1272 |
+
|
| 1273 |
+
byte_offset_ += pointer_offset * sizeof(Element);
|
| 1274 |
+
}
|
| 1275 |
+
|
| 1276 |
+
/// Returns a pointer
|
| 1277 |
+
CUTLASS_HOST_DEVICE
|
| 1278 |
+
AccessType *get() const {
|
| 1279 |
+
|
| 1280 |
+
AccessType *access_ptr = pointer_;
|
| 1281 |
+
|
| 1282 |
+
int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2);
|
| 1283 |
+
int offset_s = (iteration_strided_ / 2) * 8;
|
| 1284 |
+
|
| 1285 |
+
int access_offset = offset_c * stride_ + offset_s;
|
| 1286 |
+
|
| 1287 |
+
char *access_byte_ptr =
|
| 1288 |
+
reinterpret_cast<char *>(access_ptr + access_offset);
|
| 1289 |
+
|
| 1290 |
+
return reinterpret_cast<AccessType *>(access_byte_ptr + byte_offset_);
|
| 1291 |
+
}
|
| 1292 |
+
|
| 1293 |
+
/// Advances to the next tile in memory.
|
| 1294 |
+
CUTLASS_HOST_DEVICE
|
| 1295 |
+
RegularTileAccessIterator &operator++() {
|
| 1296 |
+
++iteration_contiguous_;
|
| 1297 |
+
|
| 1298 |
+
if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous)
|
| 1299 |
+
return *this;
|
| 1300 |
+
|
| 1301 |
+
// Enter here only if (iteration_contiguous_ ==
|
| 1302 |
+
// ThreadMap::Iteration::kContiguous)
|
| 1303 |
+
iteration_contiguous_ = 0;
|
| 1304 |
+
++iteration_strided_;
|
| 1305 |
+
|
| 1306 |
+
if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
|
| 1307 |
+
return *this;
|
| 1308 |
+
}
|
| 1309 |
+
|
| 1310 |
+
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
|
| 1311 |
+
// which means we enter the next tile.
|
| 1312 |
+
iteration_strided_ = 0;
|
| 1313 |
+
|
| 1314 |
+
return *this;
|
| 1315 |
+
}
|
| 1316 |
+
|
| 1317 |
+
/// Advances to the next tile in memory.
|
| 1318 |
+
CUTLASS_HOST_DEVICE
|
| 1319 |
+
RegularTileAccessIterator operator++(int) {
|
| 1320 |
+
|
| 1321 |
+
RegularTileAccessIterator prev(*this);
|
| 1322 |
+
|
| 1323 |
+
this->operator++();
|
| 1324 |
+
|
| 1325 |
+
return prev;
|
| 1326 |
+
}
|
| 1327 |
+
|
| 1328 |
+
/// Adds a tile offset
|
| 1329 |
+
CUTLASS_DEVICE
|
| 1330 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 1331 |
+
|
| 1332 |
+
add_pointer_offset(
|
| 1333 |
+
coord.contiguous() * Shape::kContiguous * stride_ +
|
| 1334 |
+
coord.strided() * Shape::kStrided * Layout::kElementsPerAccess);
|
| 1335 |
+
}
|
| 1336 |
+
};
|
| 1337 |
+
|
| 1338 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1339 |
+
|
| 1340 |
+
/// Tile Iterator specialized for column-major congruous TensorOp formats.
|
| 1341 |
+
///
|
| 1342 |
+
///
|
| 1343 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1344 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1345 |
+
/// WriteableContiguousTileIteratorConcept
|
| 1346 |
+
///
|
| 1347 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1348 |
+
typename ThreadMap_, int Alignment>
|
| 1349 |
+
class RegularTileAccessIterator<
|
| 1350 |
+
Shape_, Element_,
|
| 1351 |
+
layout::ColumnMajorTensorOpMultiplicandCrosswise128x4,
|
| 1352 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 1353 |
+
public:
|
| 1354 |
+
static_assert(
|
| 1355 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1356 |
+
"Specialization for column-major iterator may along advance along the "
|
| 1357 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 1358 |
+
|
| 1359 |
+
using Shape = Shape_;
|
| 1360 |
+
using Element = Element_;
|
| 1361 |
+
using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4;
|
| 1362 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1363 |
+
static int const kAlignment = Alignment;
|
| 1364 |
+
|
| 1365 |
+
using Index = typename Layout::Index;
|
| 1366 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1367 |
+
|
| 1368 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1369 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1370 |
+
|
| 1371 |
+
using ThreadMap = ThreadMap_;
|
| 1372 |
+
|
| 1373 |
+
/// Underlying iterator type
|
| 1374 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 1375 |
+
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element,
|
| 1376 |
+
layout::TensorOpMultiplicandCrosswise128x4,
|
| 1377 |
+
(kAdvanceRank == 0 ? 0 : 1), ThreadMap_>;
|
| 1378 |
+
|
| 1379 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1380 |
+
|
| 1381 |
+
private:
|
| 1382 |
+
/// Underlying iterator
|
| 1383 |
+
UnderlyingIterator iterator_;
|
| 1384 |
+
|
| 1385 |
+
public:
|
| 1386 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 1387 |
+
CUTLASS_HOST_DEVICE
|
| 1388 |
+
RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor
|
| 1389 |
+
int thread_id ///< ID of each participating thread
|
| 1390 |
+
)
|
| 1391 |
+
: iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 1392 |
+
|
| 1393 |
+
/// Overrides the internal iteration index
|
| 1394 |
+
CUTLASS_HOST_DEVICE
|
| 1395 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1396 |
+
|
| 1397 |
+
/// Adds a pointer offset in units of Element
|
| 1398 |
+
CUTLASS_HOST_DEVICE
|
| 1399 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1400 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1401 |
+
}
|
| 1402 |
+
|
| 1403 |
+
/// Returns a pointer
|
| 1404 |
+
CUTLASS_HOST_DEVICE
|
| 1405 |
+
AccessType *get() const {
|
| 1406 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1407 |
+
}
|
| 1408 |
+
|
| 1409 |
+
/// Adds a tile offset
|
| 1410 |
+
CUTLASS_DEVICE
|
| 1411 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 1412 |
+
iterator_.add_tile_offset({coord.row(), coord.column()});
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
/// Advances to the next tile in memory.
|
| 1416 |
+
CUTLASS_HOST_DEVICE
|
| 1417 |
+
RegularTileAccessIterator &operator++() {
|
| 1418 |
+
++iterator_;
|
| 1419 |
+
return *this;
|
| 1420 |
+
}
|
| 1421 |
+
|
| 1422 |
+
/// Advances to the next tile in memory.
|
| 1423 |
+
CUTLASS_HOST_DEVICE
|
| 1424 |
+
RegularTileAccessIterator operator++(int) {
|
| 1425 |
+
RegularTileAccessIterator prev(*this);
|
| 1426 |
+
++iterator_;
|
| 1427 |
+
|
| 1428 |
+
return prev;
|
| 1429 |
+
}
|
| 1430 |
+
};
|
| 1431 |
+
|
| 1432 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1433 |
+
|
| 1434 |
+
/// Tile Iterator specialized for row-major congruous TensorOp formats.
|
| 1435 |
+
///
|
| 1436 |
+
///
|
| 1437 |
+
/// Satisfies: ForwardTileIteratorConcept |
|
| 1438 |
+
/// ReadableContiguousTileIteratorConcept |
|
| 1439 |
+
/// WriteableContiguousTileIteratorConcept
|
| 1440 |
+
///
|
| 1441 |
+
template <typename Shape_, typename Element_, int AdvanceRank,
|
| 1442 |
+
typename ThreadMap_, int Alignment>
|
| 1443 |
+
class RegularTileAccessIterator<Shape_, Element_,
|
| 1444 |
+
layout::RowMajorTensorOpMultiplicandCrosswise128x4,
|
| 1445 |
+
AdvanceRank, ThreadMap_, Alignment> {
|
| 1446 |
+
public:
|
| 1447 |
+
static_assert(
|
| 1448 |
+
AdvanceRank == 0 || AdvanceRank == 1,
|
| 1449 |
+
"Specialization for row-major iterator may along advance along the "
|
| 1450 |
+
"columns(rank=0) or rows(rank=1) dimension.");
|
| 1451 |
+
|
| 1452 |
+
using Shape = Shape_;
|
| 1453 |
+
using Element = Element_;
|
| 1454 |
+
using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4;
|
| 1455 |
+
static int const kAdvanceRank = AdvanceRank;
|
| 1456 |
+
static int const kAlignment = Alignment;
|
| 1457 |
+
|
| 1458 |
+
using Index = typename Layout::Index;
|
| 1459 |
+
using LongIndex = typename Layout::LongIndex;
|
| 1460 |
+
|
| 1461 |
+
using TensorRef = TensorRef<Element, Layout>;
|
| 1462 |
+
using TensorCoord = typename Layout::TensorCoord;
|
| 1463 |
+
|
| 1464 |
+
using ThreadMap = ThreadMap_;
|
| 1465 |
+
|
| 1466 |
+
/// Underlying iterator type
|
| 1467 |
+
using UnderlyingIterator = RegularTileAccessIterator<
|
| 1468 |
+
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element,
|
| 1469 |
+
layout::TensorOpMultiplicandCrosswise128x4,
|
| 1470 |
+
(kAdvanceRank == 0 ? 1 : 0), ThreadMap_>;
|
| 1471 |
+
|
| 1472 |
+
using AccessType = typename UnderlyingIterator::AccessType;
|
| 1473 |
+
|
| 1474 |
+
private:
|
| 1475 |
+
/// Underlying iterator
|
| 1476 |
+
UnderlyingIterator iterator_;
|
| 1477 |
+
|
| 1478 |
+
public:
|
| 1479 |
+
/// Construct a TileIterator with zero threadblock offset
|
| 1480 |
+
CUTLASS_HOST_DEVICE
|
| 1481 |
+
RegularTileAccessIterator(
|
| 1482 |
+
TensorRef ref, ///< Pointer to start of tensor
|
| 1483 |
+
int thread_id ///< ID of each participating thread
|
| 1484 |
+
):
|
| 1485 |
+
iterator_({ref.data(), ref.stride()}, thread_id) {}
|
| 1486 |
+
|
| 1487 |
+
/// Overrides the internal iteration index
|
| 1488 |
+
CUTLASS_HOST_DEVICE
|
| 1489 |
+
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
|
| 1490 |
+
|
| 1491 |
+
/// Adds a pointer offset in units of Element
|
| 1492 |
+
CUTLASS_HOST_DEVICE
|
| 1493 |
+
void add_pointer_offset(LongIndex pointer_offset) {
|
| 1494 |
+
iterator_.add_pointer_offset(pointer_offset);
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
/// Returns a pointer
|
| 1498 |
+
CUTLASS_HOST_DEVICE
|
| 1499 |
+
AccessType *get() const {
|
| 1500 |
+
return reinterpret_cast<AccessType *>(iterator_.get());
|
| 1501 |
+
}
|
| 1502 |
+
|
| 1503 |
+
/// Adds a tile offset
|
| 1504 |
+
CUTLASS_DEVICE
|
| 1505 |
+
void add_tile_offset(TensorCoord const &coord) {
|
| 1506 |
+
iterator_.add_tile_offset({coord.column(), coord.row()});
|
| 1507 |
+
}
|
| 1508 |
+
|
| 1509 |
+
/// Advances to the next tile in memory.
|
| 1510 |
+
CUTLASS_HOST_DEVICE
|
| 1511 |
+
RegularTileAccessIterator &operator++() {
|
| 1512 |
+
++iterator_;
|
| 1513 |
+
return *this;
|
| 1514 |
+
}
|
| 1515 |
+
|
| 1516 |
+
/// Advances to the next tile in memory.
|
| 1517 |
+
CUTLASS_HOST_DEVICE
|
| 1518 |
+
RegularTileAccessIterator operator++(int) {
|
| 1519 |
+
RegularTileAccessIterator prev(*this);
|
| 1520 |
+
++iterator_;
|
| 1521 |
+
|
| 1522 |
+
return prev;
|
| 1523 |
+
}
|
| 1524 |
+
};
|
| 1525 |
+
|
| 1526 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1527 |
+
|
| 1528 |
+
} // namespace threadblock
|
| 1529 |
+
} // namespace transform
|
| 1530 |
+
} // namespace cutlass
|
| 1531 |
+
|
| 1532 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|